Transform to Sum Tree

Last Updated : 20 May, 2026

Given a binary tree, convert it in place into a tree where each node holds the sum of all values in its left and right subtrees of the original tree. Leaf nodes have no subtrees, so their values become 0.

Input:

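A binary tree with root 10, whose children are -2 and 6; node -2 has children 8 and -4, and node 6 has children 7 and 5 (the tree built by the driver code below).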

Output:

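The same shape with root 20 (= -2 + 8 - 4 + 6 + 7 + 5), left child 4 (= 8 - 4), right child 12 (= 7 + 5), and all four leaves set to 0. Its inorder traversal is 0 4 0 20 0 12 0.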
Try It Yourself
redirect icon

Using Hash Map or Dictionary - O(n) Time and O(n) Space

First store the sum of each subtree in a HashMap using one traversal, then in a second traversal update each node using the stored values.

  • Create a map<Node*, int> to store subtree sums
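  • First traversal (postorder): for each node compute left sum + right sum + node value and store it in the map.
  • Second traversal: for each node set new value = (sum of left subtree + sum of right subtree), reading both sums from the map (a missing child counts as 0).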
  • Return updated tree.
C++
#include <iostream>
#include <unordered_map>
using namespace std;

// Structure of a tree node: an int value plus links to the left and right
// children (both null for a newly created node).
struct Node
{
    int data;
    Node *left;
    Node *right;

    Node(int val) : data(val), left(NULL), right(NULL) {}
};

// Postorder pass: returns the sum of every value in the subtree rooted at
// `root` and records that same sum in `mp` for each visited node.
// An empty subtree contributes 0 and is not recorded.
int storeSum(Node *root, unordered_map<Node *, int> &mp)
{
    if (!root)
        return 0;

    int fromLeft = storeSum(root->left, mp);
    int fromRight = storeSum(root->right, mp);
    int subtreeTotal = root->data + fromLeft + fromRight;

    mp[root] = subtreeTotal;
    return subtreeTotal;
}

// Helper function
void solve(Node *root, unordered_map<Node *, int> &mp)
{
    if (root == NULL)
    {
        return;
    }

    // Get left subtree sum from map
    int leftSum = (root->left ? mp[root->left] : 0);

    // Get right subtree sum from map
    int rightSum = (root->right ? mp[root->right] : 0);

    // Update current node value
    root->data = leftSum + rightSum;

    // Recur for left and right subtree
    solve(root->left, mp);
    solve(root->right, mp);
}

// Converts the tree rooted at `root` into its sum tree in place: each node ends
// up holding the sum of its original left and right subtrees (leaves become 0).
// Two passes share one hash map of per-node subtree sums; an empty tree is
// left untouched.
void toSumTree(Node *root)
{
    if (root != NULL)
    {
        unordered_map<Node *, int> subtreeSums;
        storeSum(root, subtreeSums); // pass 1: record original subtree sums
        solve(root, subtreeSums);    // pass 2: overwrite node values
    }
}

// Prints the tree's values in inorder (left, node, right), each preceded by a
// single space, without a trailing newline.
void printInorder(Node *root)
{
    if (root != nullptr)
    {
        printInorder(root->left);
        cout << " " << root->data;
        printInorder(root->right);
    }
}

// Driver code
int main()
{
    // Constructing the tree
    Node *root = new Node(10);
    root->left = new Node(-2);
    root->right = new Node(6);
    root->left->left = new Node(8);
    root->left->right = new Node(-4);
    root->right->left = new Node(7);
    root->right->right = new Node(5);

    toSumTree(root);

    printInorder(root);

    return 0;
}
Java
import java.util.HashMap;

// Structure of a tree node: an int value plus links to the left and right
// children (both null for a freshly created node).
class Node {

    int data;
    Node left, right;

    Node(int val)
    {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

class Main {

    /**
     * Postorder pass: returns the sum of all values in the subtree rooted at
     * {@code root} and records that sum in {@code map} for every visited node.
     * An empty subtree contributes 0 and is not recorded.
     */
    static int storeSum(Node root, HashMap<Node, Integer> map)
    {
        if (root == null) {
            return 0;
        }

        int subtreeTotal = root.data;
        subtreeTotal += storeSum(root.left, map);
        subtreeTotal += storeSum(root.right, map);

        map.put(root, subtreeTotal);
        return subtreeTotal;
    }

    /**
     * Preorder pass: replaces each node's value with the recorded sums of its
     * original left and right subtrees (an absent child counts as 0).
     */
    static void solve(Node root, HashMap<Node, Integer> map)
    {
        if (root == null) {
            return;
        }

        int childSums = 0;
        if (root.left != null) {
            childSums += map.get(root.left);
        }
        if (root.right != null) {
            childSums += map.get(root.right);
        }
        root.data = childSums;

        solve(root.left, map);
        solve(root.right, map);
    }

    /**
     * Converts the tree rooted at {@code root} into its sum tree in place,
     * using two passes that share one map of per-node subtree sums.
     * A null tree is left untouched.
     */
    static void toSumTree(Node root)
    {
        if (root != null) {
            HashMap<Node, Integer> subtreeSums = new HashMap<>();
            storeSum(root, subtreeSums);
            solve(root, subtreeSums);
        }
    }

    /** Prints the values in inorder, each followed by a single space. */
    static void printInorder(Node root)
    {
        if (root != null) {
            printInorder(root.left);
            System.out.print(root.data + " ");
            printInorder(root.right);
        }
    }

    /** Builds the sample tree, converts it to a sum tree and prints it inorder. */
    public static void main(String[] args)
    {
        Node root = new Node(10);
        root.left = new Node(-2);
        root.right = new Node(6);
        root.left.left = new Node(8);
        root.left.right = new Node(-4);
        root.right.left = new Node(7);
        root.right.right = new Node(5);

        toSumTree(root);
        printInorder(root);
    }
}
Python
# Structure of a tree node
class Node:
    """Binary tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, val):
        self.data = val
        # Both child links start empty; callers attach children afterwards.
        self.left = self.right = None


# Function to store subtree sums
def storeSum(root, mp):
    """Return the sum of every value in the subtree rooted at ``root``.

    As a side effect, ``mp[node]`` is set to that subtree sum for every node
    visited (postorder). ``None`` (an empty subtree) yields 0 and is not
    recorded.
    """
    if root is None:
        return 0

    from_left = storeSum(root.left, mp)
    from_right = storeSum(root.right, mp)
    subtree_total = root.data + from_left + from_right

    mp[root] = subtree_total
    return subtree_total


# Helper function to convert tree into Sum Tree
def solve(root, mp):
    """Overwrite each node's value with the recorded sums of its children.

    ``mp`` must map every non-None node to the sum of its original subtree
    (as built by ``storeSum``). Nodes are rewritten parent-first; a missing
    child contributes 0.
    """
    if root is None:
        return

    left, right = root.left, root.right
    root.data = (mp[left] if left else 0) + (mp[right] if right else 0)

    solve(left, mp)
    solve(right, mp)


# Function to convert tree into Sum Tree
def toSumTree(root):
    """Convert the tree rooted at ``root`` into its sum tree, in place.

    Each node ends up holding the sum of its original left and right
    subtrees (leaves become 0). ``None`` is a no-op.
    """
    if root is not None:
        subtree_sums = {}
        storeSum(root, subtree_sums)  # pass 1: record original subtree sums
        solve(root, subtree_sums)     # pass 2: rewrite node values


# Function to print inorder traversal
def printInorder(root):
    """Print the values in inorder, each followed by a space (no newline)."""
    if root is not None:
        printInorder(root.left)
        print(root.data, end=" ")
        printInorder(root.right)


# Driver code

if __name__ == "__main__":
    # Build the sample tree:
    #          10
    #        /    \
    #      -2      6
    #     /  \    / \
    #    8   -4  7   5
    root = Node(10)
    root.left, root.right = Node(-2), Node(6)
    root.left.left, root.left.right = Node(8), Node(-4)
    root.right.left, root.right.right = Node(7), Node(5)

    # Convert to a sum tree, then print it inorder (0 4 0 20 0 12 0)
    toSumTree(root)
    printInorder(root)
C#
using System;
using System.Collections.Generic;

// Structure of a tree node: an int value plus left/right child links
// (null until a child is attached).
class Node {
    public int data;
    public Node left;
    public Node right;

    public Node(int val)
    {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

class Program {
    // Postorder pass: returns the sum of every value in the subtree rooted
    // at `root` and records that sum in `map` for each visited node.
    // An empty subtree contributes 0 and is not recorded.
    static int storeSum(Node root, Dictionary<Node, int> map)
    {
        if (root == null) {
            return 0;
        }

        int subtreeTotal = root.data;
        subtreeTotal += storeSum(root.left, map);
        subtreeTotal += storeSum(root.right, map);

        map[root] = subtreeTotal;
        return subtreeTotal;
    }

    // Preorder pass: overwrites each node with the recorded sums of its
    // original left and right subtrees (an absent child counts as 0).
    static void solve(Node root, Dictionary<Node, int> map)
    {
        if (root == null) {
            return;
        }

        int childSums = 0;
        if (root.left != null) {
            childSums += map[root.left];
        }
        if (root.right != null) {
            childSums += map[root.right];
        }
        root.data = childSums;

        solve(root.left, map);
        solve(root.right, map);
    }

    // Converts the tree into its sum tree in place using two passes that
    // share one Dictionary of per-node subtree sums. Null input is a no-op.
    static void toSumTree(Node root)
    {
        if (root != null) {
            var subtreeSums = new Dictionary<Node, int>();
            storeSum(root, subtreeSums);
            solve(root, subtreeSums);
        }
    }

    // Writes the values in inorder, each followed by a single space.
    static void printInorder(Node root)
    {
        if (root != null) {
            printInorder(root.left);
            Console.Write(root.data + " ");
            printInorder(root.right);
        }
    }

    // Driver code: builds the sample tree, converts it and prints it.
    static void Main()
    {
        var root = new Node(10);
        root.left = new Node(-2);
        root.right = new Node(6);
        root.left.left = new Node(8);
        root.left.right = new Node(-4);
        root.right.left = new Node(7);
        root.right.right = new Node(5);

        toSumTree(root);
        printInorder(root);
    }
}
JavaScript
// Structure of a tree node: a value plus left/right child links
// (null until a child is attached).
class Node {
    constructor(val)
    {
        Object.assign(this, { data: val, left: null, right: null });
    }
}

// Postorder pass: returns the sum of all values in the subtree rooted at
// `root`, storing that sum in `map` (a Map keyed by node) for each node.
// An empty subtree (null) contributes 0 and is not stored.
function storeSum(root, map)
{
    if (root === null) {
        return 0;
    }

    const fromLeft = storeSum(root.left, map);
    const fromRight = storeSum(root.right, map);
    const subtreeTotal = root.data + fromLeft + fromRight;

    map.set(root, subtreeTotal);
    return subtreeTotal;
}

// Preorder pass: overwrites each node with the recorded sums (from `map`) of
// its original left and right subtrees; an absent child counts as 0.
function solve(root, map)
{
    if (root === null) {
        return;
    }

    const sumOf = (child) => (child ? map.get(child) : 0);
    root.data = sumOf(root.left) + sumOf(root.right);

    solve(root.left, map);
    solve(root.right, map);
}

// Converts the tree rooted at `root` into its sum tree in place: every node
// ends up holding the sum of its original left and right subtrees (leaves
// become 0). Two passes share one Map of per-node subtree sums.
function toSumTree(root)
{
    if (root === null) {
        return;
    }

    const subtreeSums = new Map();
    storeSum(root, subtreeSums); // pass 1: record original subtree sums
    solve(root, subtreeSums);    // pass 2: rewrite node values
}

// Appends the tree's values in inorder (each followed by a space) to `res`
// and returns the extended string; the caller does the actual printing.
function printInorder(root, res)
{
    if (root === null) {
        return res;
    }

    const withLeft = printInorder(root.left, res);
    return printInorder(root.right, withLeft + (root.data + " "));
}

// Driver code

// Constructing the sample tree:
//          10
//        /    \
//      -2      6
//     /  \    / \
//    8   -4  7   5
let root = new Node(10);

root.left = new Node(-2);
root.right = new Node(6);

root.left.left = new Node(8);
root.left.right = new Node(-4);

root.right.left = new Node(7);
root.right.right = new Node(5);

// Convert binary tree into Sum Tree
toSumTree(root);

// Print inorder traversal. `res` must be declared: a bare `res = ""` creates
// an implicit global and throws a ReferenceError in strict mode / ES modules.
let res = printInorder(root, "");

console.log(res.trim());

Output
 0 4 0 20 0 12 0

[Alternative Approach] Using Postorder Traversal - O(n) Time and O(h) Space

The idea is to use postorder traversal (Left → Right → Node) so that we first compute the sum of left and right subtrees before updating the current node. While doing this, we keep track of the original value to return the correct subtree sum upward.

  • Traverse tree using postorder (left → right → root)
  • For each node: save its old value, recursively compute the sums of the left and right subtrees, and set the node's value to left sum + right sum.
  • Return (new node value + old value) to the parent; this is the sum of all original values in the node's subtree.
C++
#include <iostream>
using namespace std;

// A binary tree node: an int value plus links to the two children
// (both null for a newly created node).
struct Node
{
    int data;
    Node *left;
    Node *right;

    Node(int val) : data(val), left(nullptr), right(nullptr) {}
};

// Postorder helper: converts the subtree rooted at `node` into a sum tree in
// place and returns the sum of all ORIGINAL values in that subtree, which the
// caller needs to compute its own new value. An empty subtree yields 0.
int solve(Node *node)
{
    if (node == NULL)
        return 0;

    const int original = node->data;
    const int leftTotal = solve(node->left);
    const int rightTotal = solve(node->right);

    node->data = leftTotal + rightTotal;
    return node->data + original;
}

// Converts the tree rooted at `node` into a sum tree in place: each node ends
// up holding the sum of its original left and right subtrees (leaves become 0).
// The overall total returned by solve() is not needed at the top level.
void toSumTree(Node *node)
{
    (void)solve(node);
}

// Prints the values in inorder (left, node, right), each preceded by a single
// space; no trailing newline.
void printInorder(Node* node)
{
    if (node != NULL)
    {
        printInorder(node->left);
        cout << " " << node->data;
        printInorder(node->right);
    }
}

int main()
{
    Node *root = NULL;

    // Constructing tree given in the above figure
    root = new Node(10);
    root->left = new Node(-2);
    root->right = new Node(6);
    root->left->left = new Node(8);
    root->left->right = new Node(-4);
    root->right->left = new Node(7);
    root->right->right = new Node(5);

    toSumTree(root);
    printInorder(root);

    return 0;
}
C
#include <stdio.h>
#include <stdlib.h>

// A binary tree node: an int value plus links to the two children
// (NULL when a child is absent).
struct Node
{
    int data;
    struct Node *left, *right;
};

// Allocates a new leaf node holding `val` with both children NULL.
// If memory cannot be allocated, prints an error and exits, so callers never
// receive (and then dereference) a NULL pointer.
struct Node *newNode(int val)
{
    struct Node *node = malloc(sizeof *node);
    if (node == NULL)
    {
        fprintf(stderr, "newNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    node->data = val;
    node->left = NULL;
    node->right = NULL;
    return node;
}

// Postorder helper: rewrites the subtree rooted at `node` as a sum tree and
// returns the sum of its ORIGINAL values (0 for an empty subtree), which the
// parent uses to compute its own new value.
int solve(struct Node *node)
{
    if (!node)
        return 0;

    int original = node->data;
    int leftTotal = solve(node->left);
    int rightTotal = solve(node->right);

    node->data = leftTotal + rightTotal;
    return node->data + original;
}

// Converts the tree rooted at `node` into a sum tree in place; the overall
// total returned by solve() is not needed here.
void toSumTree(struct Node *node)
{
    (void)solve(node);
}

// Prints the values in inorder, each preceded by a space (no newline).
void printInorder(struct Node *node)
{
    if (node != NULL)
    {
        printInorder(node->left);
        printf(" %d", node->data);
        printInorder(node->right);
    }
}

// Frees every node of the tree rooted at `node`; children are freed before
// their parent so no link is read after its node is released.
static void freeTree(struct Node *node)
{
    if (node == NULL)
        return;
    freeTree(node->left);
    freeTree(node->right);
    free(node);
}

// Driver code: builds the sample tree, converts it to a sum tree and prints
// its inorder traversal.
int main()
{
    // Sample tree: root 10 with children -2 and 6;
    // -2 has children 8 and -4, 6 has children 7 and 5.
    struct Node *root = newNode(10);
    root->left = newNode(-2);
    root->right = newNode(6);
    root->left->left = newNode(8);
    root->left->right = newNode(-4);
    root->right->left = newNode(7);
    root->right->right = newNode(5);

    toSumTree(root);
    printInorder(root); // prints " 0 4 0 20 0 12 0"
    printf("\n");

    // All nodes came from malloc() in newNode(); release them.
    freeTree(root);

    return 0;
}
Java
class GFG {

    /** A binary tree node holding an int value and two child links. */
    static class Node {
        int data;
        Node left;
        Node right;

        Node(int val)
        {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Postorder helper: converts the subtree rooted at {@code node} into a
     * sum tree in place and returns the sum of its original values
     * (0 for an empty subtree) so the parent can compute its own value.
     */
    static int solve(Node node)
    {
        if (node == null) {
            return 0;
        }

        final int originalValue = node.data;
        final int leftTotal = solve(node.left);
        final int rightTotal = solve(node.right);

        node.data = leftTotal + rightTotal;
        return node.data + originalValue;
    }

    /** Converts the tree rooted at {@code node} into its sum tree in place. */
    static void toSumTree(Node node)
    {
        solve(node);
    }

    /** Prints the values in inorder, each followed by a single space. */
    static void printInorder(Node node)
    {
        if (node != null) {
            printInorder(node.left);
            System.out.print(node.data + " ");
            printInorder(node.right);
        }
    }

    /** Builds the sample tree, converts it to a sum tree and prints it inorder. */
    public static void main(String[] args)
    {
        Node root = new Node(10);
        root.left = new Node(-2);
        root.right = new Node(6);
        root.left.left = new Node(8);
        root.left.right = new Node(-4);
        root.right.left = new Node(7);
        root.right.right = new Node(5);

        toSumTree(root);
        printInorder(root);
    }
}
Python
# A tree node structure
class Node:
    """Binary tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, val):
        self.data = val
        # Both child links start empty; callers attach children afterwards.
        self.left = self.right = None

# Helper function to convert a given tree to sum tree
def solve(node):
    """Turn the subtree at ``node`` into a sum tree in place (postorder).

    Returns the sum of the subtree's ORIGINAL values (0 for ``None``) so that
    the parent can compute its own new value.
    """
    if node is None:
        return 0

    original = node.data
    left_total = solve(node.left)
    right_total = solve(node.right)

    node.data = left_total + right_total
    return node.data + original

# Convert tree to sum tree
def toSumTree(node):
    """Convert the tree rooted at ``node`` into a sum tree in place.

    The overall total returned by ``solve`` is not needed at the top level.
    """
    _ = solve(node)

# Inorder traversal
def printInorder(node):
    """Print the values in inorder, each followed by a space (no newline)."""
    if node is not None:
        printInorder(node.left)
        print(node.data, end=" ")
        printInorder(node.right)


# Driver code
if __name__ == "__main__":
    # Sample tree:
    #          10
    #        /    \
    #      -2      6
    #     /  \    / \
    #    8   -4  7   5
    root = Node(10)
    root.left, root.right = Node(-2), Node(6)
    root.left.left, root.left.right = Node(8), Node(-4)
    root.right.left, root.right.right = Node(7), Node(5)

    # Convert to a sum tree, then print it inorder (0 4 0 20 0 12 0)
    toSumTree(root)
    printInorder(root)
C#
using System;

// A binary tree node: an int value plus left/right child links
// (null until a child is attached).
class Node {
    public int data;
    public Node left;
    public Node right;

    public Node(int val)
    {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

class Program {
    // Postorder helper: turns the subtree rooted at `node` into a sum tree
    // in place and returns the sum of its ORIGINAL values (0 when null),
    // which the parent needs to compute its own new value.
    static int Solve(Node node)
    {
        if (node == null) {
            return 0;
        }

        int original = node.data;
        int leftTotal = Solve(node.left);
        int rightTotal = Solve(node.right);

        node.data = leftTotal + rightTotal;
        return node.data + original;
    }

    // Converts the whole tree into its sum tree; the total is not needed here.
    static void toSumTree(Node node)
    {
        Solve(node);
    }

    // Writes the values in inorder, each preceded by a single space.
    static void PrintInorder(Node node)
    {
        if (node != null) {
            PrintInorder(node.left);
            Console.Write(" " + node.data);
            PrintInorder(node.right);
        }
    }

    // Driver code: builds the sample tree, converts it and prints it.
    static void Main()
    {
        var root = new Node(10);
        root.left = new Node(-2);
        root.right = new Node(6);
        root.left.left = new Node(8);
        root.left.right = new Node(-4);
        root.right.left = new Node(7);
        root.right.right = new Node(5);

        toSumTree(root);
        PrintInorder(root);
    }
}
JavaScript
// A binary tree node: a value plus left/right child links
// (null until a child is attached).
class Node {
    constructor(val)
    {
        Object.assign(this, { data: val, left: null, right: null });
    }
}

// Postorder helper: rewrites the subtree rooted at `root` as a sum tree and
// returns the sum of its ORIGINAL values (0 for null) for use by the parent.
function solve(root)
{
    if (root === null) {
        return 0;
    }

    const original = root.data;
    const leftTotal = solve(root.left);
    const rightTotal = solve(root.right);

    root.data = leftTotal + rightTotal;
    return root.data + original;
}

// Converts the tree rooted at `root` into its sum tree in place; the total
// returned by solve() is discarded.
function toSumTree(root)
{
    solve(root);
}

// Appends the tree's values in inorder (each followed by a space) to `res`
// and returns the extended string; the caller does the actual printing.
function printInorder(root, res)
{
    if (root === null) {
        return res;
    }

    const withLeft = printInorder(root.left, res);
    return printInorder(root.right, withLeft + (root.data + " "));
}

// Driver Code

// Sample tree:
//          10
//        /    \
//      -2      6
//     /  \    / \
//    8   -4  7   5
let root = new Node(10);
root.left = new Node(-2);
root.right = new Node(6);
root.left.left = new Node(8);
root.left.right = new Node(-4);
root.right.left = new Node(7);
root.right.right = new Node(5);

// Convert tree into Sum Tree
toSumTree(root);

// Collect the inorder traversal into a string and print it without the
// trailing space.
let res = printInorder(root, "");
console.log(res.trim());

Output
 0 4 0 20 0 12 0
Comment