레야몬

[C++] 16978번 수열과 쿼리 22 - 자료 구조, 세그먼트 트리, 오프라인 쿼리 본문

알고리즘/백준

[C++] 16978번 수열과 쿼리 22 - 자료 구조, 세그먼트 트리, 오프라인 쿼리

Leyamon 2022. 11. 18. 11:29

1. 문제

  • 길이가 N인 수열 Ai가 주어진다. 이때, 아래 쿼리를 수행하는 프로그램을 작성하시오.
    • 1 i v : Ai = v로 변경
    • k i j : k번째 1번 쿼리가 적용되었을 때 Ai, Ai+1, ..., Aj의 합을 출력

<입력>

  • -1-   수열의 크기 \(N(1 \leq N \leq 100,000)\)
  • -2-   A1, A2, ..., AN \((1 \leq A_{i} \leq 1,000,000)\)
  • -3-   쿼리의 개수 \(M(1 \leq M \leq 100,000)\)
  • -M줄-   쿼리가 한 줄에 하나씩 주어짐 \((1 \leq i \leq N, 1 \leq v \leq 1,000,000, 1 \leq i \leq j \leq N, 0 \leq k \leq (쿼리가 주어진 시점까지 있었던 1번 쿼리의 수)\)

<출력>

  • 모든 2번 쿼리마다 합을 출력한다.

2. 재정의

  • X

3. 해결방법

  • tr에 세그먼트 트리 형성
  • 쿼리에 1번 쿼리가 나온 순서, k가 작을수록 쿼리를 정렬한 후 f=1이 앞에 오도록 정렬.
  • 쿼리를 수행
  • 오프라인 쿼리라서 순서를 바꿔줘서 해결할 수 있다.

4. 실수한 점, 개선할 점

  • 1번의 i가 아닌 1번이 나온 쿼리 순서와 2번의 k가 작을수록 정렬해줘야 한다.
  • 세그먼트 트리가 아닌 세그먼트 트리 파생형 트리가 더 속도가 빠른 것 같다.
  • f == 1이 아닌 f & 1로, no*2이 아닌 no<<1로 하면 소요 시간을 줄일 수 있다.
  • no<<1+1인 no<<(1+1)이다.

 

<코드>

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;
typedef long long ll;

const int MAX_N = 100001;
const int MAX_M = 100001;

// Query
struct Query {
    int f, k, s, e, idx;
    bool operator < (Query &x) {
        int a, b;
        
        if(f & 1) a = idx;
        else a = k;
        if(x.f & 1) b = x.idx;
        else b = x.k;
        
        if(a!=b)
            return a < b;
        return f < x.f;
    }
};

// 수열의 크기 N과, 수열 A[i], 쿼리의 개수 M
int N, A[MAX_N], M;
// 쿼리 배열
vector<Query> query;
ll res[MAX_M];

// 세그먼트 트리
ll tr[MAX_N*4];
// idx
int cnt1, cnt2=1;

// 세그먼트 트리 구현
ll init(int s, int e, int no) {
    if(s==e)
        return tr[no] = A[s];
    int m = (s+e)/2;
    return tr[no] = init(s, m, no<<1) + init(m+1, e, (no<<1)+1);
}

// 세그먼트 트리 쿼리
ll sum(int s, int e, int no, int l, int r) {
    if(l>e || r<s)
        return 0;
    if(l<=s && e<=r)
        return tr[no];
    ll m = (s+e)/2;
    return sum(s, m, no<<1, l, r) + sum(m+1, e, (no<<1)+1, l, r);
}

// 세그먼트 트리 업데이트
void update(int s, int e, int no, int idx, int va) {
    if(idx<s || idx>e)
        return;
    tr[no] += va;
    if(s == e)
        return;
    ll m = (s+e)/2;
    
    update(s, m, no<<1, idx, va);
    update(m+1, e, (no<<1)+1, idx, va);
}

void input() {
    cin >> N;
    for(int i=1; i<=N; i++)
        cin >> A[i];
    cin >> M;
    
    for(int i=0; i<M; i++) {
        int f, k, s, e;
        cin >> f;
        switch(f) {
            case 1:
                cin >> k >> s;
                query.push_back({1, k, s, 0, cnt2++});
                break;
            case 2:
                cin >> k >> s >> e;
                query.push_back({2, k, s, e, cnt1++});
        }
    }
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    
    input();
    
    // query를 k순으로 정렬해주기
    sort(query.begin(), query.end());
    // 세그먼트 트리 생성
    init(1, N, 1);
    
    // 쿼리 수행
    for(int i=0; i<M; i++) {
        if(query[i].f == 1) {
            update(1, N, 1, query[i].k, query[i].s - A[query[i].k]);
            A[query[i].k] = query[i].s;
        }
        else
            res[query[i].idx] = sum(1, N, 1, query[i].s, query[i].e);
    }
    
    // 쿼리 결과 출력
    for(int i=0; i<cnt1; i++)
        cout << res[i] << '\n';
    
    return 0;
}

 

 

※현재 고등학교 등교중인 학생입니다. 이제 알고리즘을 본격적으로 공부하기 시작해서 아직 초보입니다. 혹시 제가 잘못 알고 있는 점이나 더 좋은 풀이 방법이 있어 댓글에 남겨주시면 감사히 하나하나 열심히 읽어보겠습니다. 좋아요, 단순한 댓글 한마디라도 저에겐 큰 힘이 됩니다! 감사합니다

Comments