레야몬

[C++] 13725번 RNG - 수학, FFT, NTT, 키타마사 본문

알고리즘/백준

[C++] 13725번 RNG - 수학, FFT, NTT, 키타마사

Leyamon 2023. 8. 25. 16:49

13725번 RNG

1. 문제

  • 랜덤 숫자 생성기(RNG)는 아래와 같은 선형 점화식으로 나타낼 수 있다.
  • \(A_i = (A_{i-1} \times C_1 + A_{i-2} \times C_2 + \cdots + A_{i-k} \times C_k)          mod 104857601\)
  • N과 A1, A2, ..., Ak 그리고 C1, C2, ..., Ck가 주어졌을 때, AN을 구하는 프로그램을 작성하시오.

입력

  • -1- : \(k, N(1 \leq k \leq 30,000, 1 \leq N \leq 10^{18})\)
  • -1- : \(A_i, C_i(0 \leq A_i, C_i < 104857601)\)

출력

  • -1- A_N

2. 재정의

  • X

3. 해결 방법

  • 키타마사법에서 다항식의 곱과 나머지 연산을 NTT를 이용해 처리하면 된다.
  • 자세한 해결 방법은 아래 글에 참조된 링크를 들어가면 다른 분들이 자세하게 써주신 글을 확인하면 볼 수 있다. (저는 이를 다룰 자신이 없는데ㅏ....ㅠㅠ)

4. 실수한 점, 개선할 점

  • 다항식의 역원을 구하는 과정에서 다항식의 뺄셈을 하는 데 이 경우 다항식의 차수를 정하는 과정이 어려워 모든 다항식의 자료형의 크기를 최댓값 하나로 고정시켰다.
  • 맨 처음에는 FFT의 연산이 더 빨라서 이것으로 했다가 정확도가 떨어져 다항식의 계수는 정수로 놓고 다항식을 2개로 쪼개서 FFT를 했으나 이것도 정확도가 떨어져서 결국 NTT로 갈아탔다.
  • NTT의 최적화를 해줘야 했다.
    1. 비트를 뒤집는 연산은 다항식의 크기가 정해져 있었기에 미리 캐싱해서 사용하였다.
    2. NTT 과정에서 f(w) = f_even(w^2) + w*f_odd(w^2)를 할 때 배열의 위치를 j+k+i 이런 식으로 했었는데 n이 2^k꼴인 것을 이용하여 | 연산으로 고쳐 개선하였다.
    3. 다항식을 뒤집고, x^k로 모듈러 연산을 해주는 과정을 나는 모두 for문으로 하였는데 다른 사람의 코드를 보니 vector 연산 중 reverse()와 resize 연산으로 해결하는 경우도 있었다. 이 경우가 더 시간 소요가 더 적게 나왔기 때문에 이도 추천한다.

 

FTT, NTT코드는 있어도 13725번 RNG문제의 정답 코드는 인터넷에서 아무리 찾아봐도 찾지 못했기에 그 첫 번째 사람이 되었으면 좋겠다.

 

<코드>

#include <iostream>
#include <vector>

using namespace std;

#define SWAP(type, x, y) do{ type tmp = x; x = y; y = tmp;}while(0)

typedef long long int ll;
typedef vector<int> vi;
typedef vector<ll> vll;

const ll mod = 104857601;
const ll MAX_K = 30000;
const ll MAX_SIZE = 131072;
const ll w = 3;

// <문제>
// 선형점화식의 길이 k, 구하고자 하는 항 N
ll k, N;
// 선형점화식의 계수 A[i]
vll A(MAX_K);
// 선형점화식의 특성방정식, 뒤집은 f, f의 역원
vll f(MAX_SIZE), revf(MAX_SIZE), invf(MAX_SIZE);
// fx = x
vll fx(MAX_SIZE);
vll revbit(MAX_SIZE);

// a ^ b
ll pw(ll a, ll b) {
    ll ret = 1;
    while(b) {
        if(b & 1) ret = ret * a % mod;
        b >>= 1; a = a * a % mod;
    }
    return ret;
}

// Numeric Theoretic Transform
void ntt(vll &f, bool inv = 0) {
    int n=MAX_SIZE, j=0;
    vll root(n >> 1);
    
    for (int i=0; i<n; i++)
        if (i < revbit[i]) 
            SWAP(ll, f[i], f[revbit[i]]);
    
    ll ang = pw(w, (mod-1)/n); if(inv) ang = pw(ang, mod-2);
    root[0] = 1; 
    for(int i=1; i<(n >> 1); i++)
        root[i] = root[i-1] * ang % mod;
    
    for(int i=2; i<=n; i<<=1) {
        int step = n/i;
        for(int j=0; j<n; j+=i) {
            for(int k=0; k<(i >> 1); k++) {
                ll u = f[j | k], v = f[j | k | i >> 1] * root[step*k] % mod;
                f[j | k] = (u + v) % mod;
                f[j | k | i >> 1] = (u - v) % mod;
                if(f[j | k | i >> 1] < 0) f[j | k | i >> 1] += mod;
            }
        }
    }
    if(inv) {
        ll t = pw(n, mod - 2);
        for(int i=0; i<n; i++)
            f[i] = f[i] * t % mod;
    }
}

// 다항식의 곱 반환
void pol_mul(vll &_a, vll _b) {
    int n = MAX_SIZE;
    
    ntt(_a); ntt(_b);
    for(int i=0; i<n; i++) 
        _a[i] = _a[i] * _b[i] % mod;
    ntt(_a, 1);
}

// res를 f로 나눈 나머지 구하기
void pol_div(vll &a) {
    int n = (int)a.size() - 1;
    while(!a[n] && n>0) n--;
    if(n<k) return;
 
    vll ftmp(a);
    for(int i=0; i<(n+1)/2; i++)
        SWAP(ll, ftmp[i], ftmp[n-i]);
    pol_mul(ftmp, invf);
    for(int i=n-k+1; i<MAX_SIZE; i++)
        ftmp[i] = 0;
    for(int i=0; i<(n-k+1)/2; i++)
        SWAP(ll, ftmp[i], ftmp[n-k-i]);
    pol_mul(ftmp, f);
    for(int i=0; i<=n; i++)
        a[i] -= ftmp[i];
}

// 키타마사법
vll Kitamasa(ll n) {
    vll res(MAX_SIZE); res[0] = 1;
    
    while(n) {
        if(n & 1) {
            pol_mul(res, fx);
            pol_div(res);
        }
        if(n > 1) {
            pol_mul(fx, fx);
            pol_div(fx);
        }
        n >>= 1;
    }

    return res;
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    
    for (int i=1, j=0; i<MAX_SIZE; i++) {
        int bit = MAX_SIZE >> 1;
        while (!((j ^= bit) & bit)) bit >>= 1;
        if (i < j) revbit[i] = j;
    }
    
    fx[1] = 1;
    
    cin >> k >> N;
    for(int i=0; i<k; i++)
        cin >> A[i];
    for(int i=0; i<k; i++) {
        int c; cin >> c;
        f[k-i-1] = -c;
        revf[i+1] = -c;
    } f[k] = 1; revf[0] = 1;
    
    // 함수 revf의 역원 구하기
    // invf =_ invf * (2 - invf * resf) mod x^2k
    int degree = 1;
    invf[0] = 1;
    
    while(degree < k) {
        degree <<= 1;
        vll ftmp(revf);
        pol_mul(ftmp, invf);
        for(int i=degree; i<MAX_SIZE; i++)
            ftmp[i] = 0;
        for(int i=0; i<MAX_SIZE; i++)
            ftmp[i] = -ftmp[i];
        ftmp[0] += 2;
        pol_mul(invf, ftmp);
        for(int i=degree; i<MAX_SIZE; i++)
            invf[i] = 0;
    }
    
    // x^(N-1)을 f로 나눈 나머지 구하기
    vll fres = Kitamasa(N-1);
    
    ll res = 0;
    for(int i=0; i<k; i++) {
        res += fres[i] * A[i] % mod;
        res %= mod;
        if(res < 0) res += mod;
    }
    cout << res;
    
    return 0;
}

 

<문제 바로가기>

https://www.acmicpc.net/problem/13725

 

13725번: RNG

첫째 줄에 k와 N (1 ≤ k ≤ 30,000, 1 ≤ N ≤ 1018)이 주어진다. 둘째 줄에는 A1, A2, ..., Ak가 셋째 줄에는 C1, C2, ..., Ck가 주어진다. (0 ≤ Ai, Ci < 104857601)

www.acmicpc.net

 

<Linear Recurrence>

https://algoshitpo.github.io/2020/05/20/linear-recurrence/

 

Linear Recurrence

이번 글에서는 선형 점화식의 $n$번째 항을 빠르게 구하는 방법에 대해 알아보도록 하겠습니다. 문제 다음과 같이 정의되는 무한수열 ${a_n}$이 있습니다. \[a_n = c_1 a_{n-1} + c_2 a_{n-2} + \ldots + c_k a_{n-

algoshitpo.github.io

 

<정확도 높은 FFT와 NTT>

https://algoshitpo.github.io/2020/05/20/fft-ntt/

 

정확도 높은 FFT와 NTT

FFT에서는 실수 자료형을 사용하기 때문에 실수 오차가 발생할 수 있고, 이는 즐거운 PS생활에 큰 지장을 줄 수 있습니다. 특히 FFT 문제에서 수가 너무 크기 때문에 M으로 나눈 나머지를 출력한다.

algoshitpo.github.io

 

<FFT의 원리 1>

https://justicehui.github.io/hard-algorithm/2019/09/04/FFT/

 

FFT in PS

목차 convolution 다항식의 표현 DFT n-th root of unity DFT와 n-th root of unity FFT IDFT 예제) 큰 수 곱셈

justicehui.github.io

 

<FFT의 원리 2>

https://speakerdeck.com/wookayin/fast-fourier-transform-algorithm?slide=35 

 

Fast Fourier Transform Algorithm

A Introduction to FFT (Fast Fourier Transform) Algorithm with its application in competitive programming. This talk was given in the 2012 SNUPS (Seoul National University Problem Solving Group) Algorithm seminar.

speakerdeck.com

 

<FFT 최적화>

https://tistory.joonhyung.xyz/6

 

Fast Fourier Transform

고속 푸리에 변환(Fast Fourier Transform, FFT)은 convolution을 $O(N\log N)$에 구할 때 활용된다. 이 포스트에서는 코드 자체보다도 FFT 알고리즘의 원리를 알아보는 것이 목적이다. 코드만 보고싶다면 맨 아

tistory.joonhyung.xyz

 

※ 궁금한 것 질문 남겨주시면 답변드리겠습니다. 좋아요 눌러주시고 편하게 질문해 주세요.

Comments