Introduction to Segment Trees

A segment tree is a binary tree in which each node stores an aggregate (here, the sum) of a contiguous range of an array. A classic problem it solves efficiently is the Range Sum Query problem: you need to quickly compute the sum of the elements in a given range of an array while also supporting updates to individual elements.

Problem Statement

Given an array of integers, you need to:

  1. Compute the sum of the elements in the range [l, r] (both ends inclusive).
  2. Update the value of an element at a specific index.

Example

Let’s consider an example with an array:

arr = [1, 3, 5, 7, 9, 11]

  • Query: Sum of elements from index 1 to 3 (inclusive), which should return 3 + 5 + 7 = 15.
  • Update: Set the value at index 1 to 10, making the array [1, 10, 5, 7, 9, 11].
  • Query: Now, sum from index 1 to 3 should return 10 + 5 + 7 = 22.

Segment Tree Implementation in C++

Here’s a simple implementation of the segment tree for the Range Sum Query problem:

#include <iostream>
#include <vector>

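// Segment tree for range sums. Node 0 covers the whole array; the children of
// node i are 2 * i + 1 (left half) and 2 * i + 2 (right half).
class SegmentTree {
private:
    std::vector<int> tree; // tree[node] = sum over the node's range (use long long if sums may overflow int)
    std::vector<int> arr;  // copy of the input, kept in sync by update()
    int n;                 // number of elements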

    void buildTree(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            buildTree(2 * node + 1, start, mid);
            buildTree(2 * node + 2, mid + 1, end);
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    void updateTree(int node, int start, int end, int idx, int value) {
        if (start == end) {
            arr[idx] = value;
            tree[node] = value;
        } else {
            int mid = (start + end) / 2;
            if (idx <= mid) {
                updateTree(2 * node + 1, start, mid, idx, value);
            } else {
                updateTree(2 * node + 2, mid + 1, end, idx, value);
            }
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    int rangeQuery(int node, int start, int end, int l, int r) {
        if (r < start || end < l) {
            return 0; // Out of range
        }
        if (l <= start && end <= r) {
            return tree[node]; // Fully in range
        }
        int mid = (start + end) / 2;
        int leftSum = rangeQuery(2 * node + 1, start, mid, l, r);
        int rightSum = rangeQuery(2 * node + 2, mid + 1, end, l, r);
        return leftSum + rightSum;
    }

public:
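    explicit SegmentTree(const std::vector<int>& input) : arr(input) {
        n = static_cast<int>(input.size());
        tree.resize(4 * n); // 4n slots are always enough for this recursive layout
        if (n > 0) {
            buildTree(0, 0, n - 1); // an empty array has no root to build
        }
    }

    // Sets arr[idx] = value. Precondition: 0 <= idx < n.
    void update(int idx, int value) {
        updateTree(0, 0, n - 1, idx, value);
    }

    // Returns arr[l] + ... + arr[r]. Precondition: 0 <= l <= r < n.
    int sumRange(int l, int r) {
        return rangeQuery(0, 0, n - 1, l, r);
    }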
};

int main() {
    std::vector<int> input = {1, 3, 5, 7, 9, 11};
    SegmentTree segTree(input);

    std::cout << "Sum of range (1, 3): " << segTree.sumRange(1, 3) << std::endl; // Outputs 15

    segTree.update(1, 10);
    std::cout << "Sum of range (1, 3) after update: " << segTree.sumRange(1, 3) << std::endl; // Outputs 22

    return 0;
}

Explanation

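  1. Building the Segment Tree: buildTree recursively splits the range [start, end] at mid. Each leaf stores a single array element, and each internal node stores the sum of its two children (at indices 2 * node + 1 and 2 * node + 2). Every node is visited once, so construction takes O(n) time.
  2. Updating Values: updateTree follows the single root-to-leaf path whose ranges contain idx, overwrites the leaf (and arr[idx]), and recomputes the sum of every node on that path as the recursion unwinds.
  3. Range Query: rangeQuery returns 0 for a segment that lies entirely outside [l, r], returns the stored sum for a segment that lies entirely inside it, and otherwise recurses into both children and adds their results.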

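Updates and queries both run in O(log n) time, because each one touches only a constant number of nodes per level of a tree whose height is about log2 n. Building the tree takes O(n) time, and the tree occupies at most 4n slots. By contrast, scanning a plain array costs O(n) per range sum, so the segment tree is a good fit when queries and updates are both frequent.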