Given a binary tree, convert it in place into a tree where each node contains the sum of the values in its left and right subtrees of the original tree. The values of leaf nodes therefore become 0.
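Input: the tree with root 10 and children -2 and 6, where -2 has children 8 and -4, and 6 has children 7 and 5.
Output: the tree with root 20 (= (-2 + 8 - 4) + (6 + 7 + 5)) and children 4 (= 8 - 4) and 12 (= 7 + 5); all four leaves become 0. Its inorder traversal is 0 4 0 20 0 12 0.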
Try It Yourself
Table of Contents
Using Hash Map or Dictionary - O(n) Time and O(n) Space
First store the sum of each subtree in a HashMap using one traversal, then in a second traversal update each node using the stored values.
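- Create a map<Node*, int> to store the sum of every subtree.
- First traversal (postorder): for each node, compute left subtree sum + right subtree sum + node value and store it in the map.
- Second traversal: for each node, set its new value = (sum of left subtree + sum of right subtree), reading both sums from the map (a missing child contributes 0).
- Return the updated tree.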
#include <iostream>
#include <unordered_map>
using namespace std;
// Structure of a tree node
struct Node
{
    int data;   // value stored in this node
    Node *left;  // left child, or nullptr if absent
    Node *right; // right child, or nullptr if absent

    // Creates a node holding `val` with no children.
    // `explicit` prevents accidental implicit int -> Node conversions.
    explicit Node(int val) : data(val), left(nullptr), right(nullptr) {}
};
// Function to store subtree sums
// Postorder walk: returns the sum of every value in the subtree rooted at
// `root` and records that sum in `mp` under the node pointer. An empty
// subtree contributes 0 and is not recorded.
int storeSum(Node *root, unordered_map<Node *, int> &mp)
{
    if (!root)
        return 0;

    const int fromLeft = storeSum(root->left, mp);
    const int fromRight = storeSum(root->right, mp);

    // Remember the whole-subtree sum so the second pass can look it up
    return mp[root] = root->data + fromLeft + fromRight;
}
// Helper function
void solve(Node *root, unordered_map<Node *, int> &mp)
{
if (root == NULL)
{
return;
}
// Get left subtree sum from map
int leftSum = (root->left ? mp[root->left] : 0);
// Get right subtree sum from map
int rightSum = (root->right ? mp[root->right] : 0);
// Update current node value
root->data = leftSum + rightSum;
// Recur for left and right subtree
solve(root->left, mp);
solve(root->right, mp);
}
// Function to convert tree into sum tree
// Replaces each node's value with the sum of its original left and right
// subtrees (leaves become 0). Two passes: the first records every subtree
// sum, the second overwrites node values from those records.
void toSumTree(Node *root)
{
    if (root != nullptr)
    {
        unordered_map<Node *, int> subtreeSums;
        storeSum(root, subtreeSums);
        solve(root, subtreeSums);
    }
}
// Function to print inorder traversal
// Writes node values in inorder (left, node, right), each preceded by a
// single space, without a trailing newline.
void printInorder(Node *root)
{
    if (root != nullptr)
    {
        printInorder(root->left);
        cout << ' ' << root->data;
        printInorder(root->right);
    }
}
// Driver code
int main()
{
// Constructing the tree
Node *root = new Node(10);
root->left = new Node(-2);
root->right = new Node(6);
root->left->left = new Node(8);
root->left->right = new Node(-4);
root->right->left = new Node(7);
root->right->right = new Node(5);
toSumTree(root);
printInorder(root);
return 0;
}
import java.util.HashMap;
// Structure of a tree node: an int payload plus links to the two children.
class Node {
    int data;
    Node left;
    Node right;

    /** Creates a childless node holding {@code val}; callers link children later. */
    Node(int val)
    {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}
class Main {
    /**
     * First pass: returns the sum of every value in the subtree rooted at
     * {@code root} (0 for {@code null}) and records it in {@code map} under
     * that node.
     */
    static int storeSum(Node root, HashMap<Node, Integer> map)
    {
        if (root == null) {
            return 0;
        }
        // Children are summed before the node itself is recorded (postorder)
        int subtreeTotal = root.data;
        subtreeTotal += storeSum(root.left, map);
        subtreeTotal += storeSum(root.right, map);
        map.put(root, subtreeTotal);
        return subtreeTotal;
    }

    /**
     * Second pass: sets every node's value to the recorded sums of its left
     * and right subtrees, then descends into both children.
     */
    static void solve(Node root, HashMap<Node, Integer> map)
    {
        if (root == null) {
            return;
        }
        root.data = recordedSum(root.left, map) + recordedSum(root.right, map);
        solve(root.left, map);
        solve(root.right, map);
    }

    /** Returns the recorded subtree sum of {@code child}, or 0 when the child is absent. */
    private static int recordedSum(Node child, HashMap<Node, Integer> map)
    {
        return (child == null) ? 0 : map.get(child);
    }

    /** Converts the tree rooted at {@code root} into a sum tree in place. */
    static void toSumTree(Node root)
    {
        if (root != null) {
            HashMap<Node, Integer> subtreeSums = new HashMap<>();
            storeSum(root, subtreeSums);
            solve(root, subtreeSums);
        }
    }

    /** Prints node values in inorder, each followed by a space, with no newline. */
    static void printInorder(Node root)
    {
        if (root != null) {
            printInorder(root.left);
            System.out.print(root.data + " ");
            printInorder(root.right);
        }
    }

    /** Builds the sample tree, converts it, and prints the resulting inorder traversal. */
    public static void main(String[] args)
    {
        Node leftChild = new Node(-2);
        leftChild.left = new Node(8);
        leftChild.right = new Node(-4);

        Node rightChild = new Node(6);
        rightChild.left = new Node(7);
        rightChild.right = new Node(5);

        Node root = new Node(10);
        root.left = leftChild;
        root.right = rightChild;

        toSumTree(root);
        printInorder(root);
    }
}
# Structure of a tree node
class Node:
    """Binary tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, val):
        # Payload first; children start empty until the caller links them.
        self.data = val
        self.left, self.right = None, None
# Function to store subtree sums
def storeSum(root, mp):
    """Record the sum of every subtree under ``root`` in ``mp``.

    After the call, ``mp[node]`` holds ``node.data`` plus the data of all of
    its descendants, for every node reachable from ``root``.

    Args:
        root: root of the (sub)tree, or None for an empty tree.
        mp: dict updated in place, keyed by node objects.

    Returns:
        The sum of all values in the tree rooted at ``root`` (0 if None).
    """
    if root is None:
        return 0
    # Iterative postorder with an explicit stack: the recursive form raises
    # RecursionError on skewed trees deeper than Python's recursion limit.
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Both children (if any) were finished earlier, so their sums are in mp.
            left_sum = mp[node.left] if node.left is not None else 0
            right_sum = mp[node.right] if node.right is not None else 0
            mp[node] = node.data + left_sum + right_sum
        else:
            # Re-push the node so it is completed only after its children.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
    return mp[root]
# Helper function to convert tree into Sum Tree
def solve(root, mp):
    """Overwrite each node with the sum of its original left and right subtrees.

    Args:
        root: root of the (sub)tree to convert, or None (no-op).
        mp: dict mapping every node under ``root`` to its original subtree
            sum, as filled by ``storeSum``. It is only read here.
    """
    # Iterative preorder: each node's new value depends only on ``mp``, so the
    # visiting order is irrelevant, and avoiding recursion keeps deep skewed
    # trees from hitting Python's recursion limit.
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        left_sum = mp[node.left] if node.left is not None else 0
        right_sum = mp[node.right] if node.right is not None else 0
        node.data = left_sum + right_sum
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
# Function to convert tree into Sum Tree
def toSumTree(root):
    """Convert the tree under ``root`` in place into a sum tree.

    Each node ends up holding the sum of its original left and right subtrees,
    so leaves become 0. Does nothing when ``root`` is None.
    """
    if root is not None:
        subtree_sums = {}
        storeSum(root, subtree_sums)  # pass 1: remember original subtree sums
        solve(root, subtree_sums)     # pass 2: overwrite node values
# Function to print inorder traversal
def printInorder(root):
    """Print node values in inorder (left, node, right).

    Each value is followed by a single space; no newline is printed.
    """
    if root is not None:
        printInorder(root.left)
        print(root.data, end=" ")
        printInorder(root.right)
# Driver code
if __name__ == "__main__":
    # Sample tree:
    #          10
    #        /    \
    #      -2      6
    #     /  \    / \
    #    8   -4  7   5
    root = Node(10)
    root.left, root.right = Node(-2), Node(6)
    root.left.left, root.left.right = Node(8), Node(-4)
    root.right.left, root.right.right = Node(7), Node(5)
    toSumTree(root)  # expected inorder afterwards: 0 4 0 20 0 12 0
    printInorder(root)
using System;
using System.Collections.Generic;
// Structure of a tree node: integer payload plus left/right child references.
class Node {
    public int data;
    public Node left;
    public Node right;

    // Builds a childless node holding `val`; children are linked afterwards.
    public Node(int val)
    {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}
class Program {
    // First pass: returns the sum of all values in the subtree rooted at
    // `root` (0 for null) and records that sum in `map` under the node.
    static int storeSum(Node root, Dictionary<Node, int> map)
    {
        if (root == null) {
            return 0;
        }
        // Children are summed before this node is recorded (postorder)
        int subtreeTotal = root.data;
        subtreeTotal += storeSum(root.left, map);
        subtreeTotal += storeSum(root.right, map);
        map[root] = subtreeTotal;
        return subtreeTotal;
    }

    // Looks up a child's recorded subtree sum; an absent child counts as 0.
    static int RecordedSum(Node child, Dictionary<Node, int> map)
    {
        return child == null ? 0 : map[child];
    }

    // Second pass: sets each node to the recorded sums of its two subtrees.
    static void solve(Node root, Dictionary<Node, int> map)
    {
        if (root == null) {
            return;
        }
        root.data = RecordedSum(root.left, map) + RecordedSum(root.right, map);
        solve(root.left, map);
        solve(root.right, map);
    }

    // Converts the tree rooted at `root` into a sum tree in place.
    static void toSumTree(Node root)
    {
        if (root != null) {
            var subtreeSums = new Dictionary<Node, int>();
            storeSum(root, subtreeSums);
            solve(root, subtreeSums);
        }
    }

    // Writes node values in inorder, each followed by a space, no newline.
    static void printInorder(Node root)
    {
        if (root != null) {
            printInorder(root.left);
            Console.Write(root.data + " ");
            printInorder(root.right);
        }
    }

    // Driver code: builds the sample tree, converts it and prints the result.
    static void Main()
    {
        Node leftChild = new Node(-2);
        leftChild.left = new Node(8);
        leftChild.right = new Node(-4);

        Node rightChild = new Node(6);
        rightChild.left = new Node(7);
        rightChild.right = new Node(5);

        Node root = new Node(10);
        root.left = leftChild;
        root.right = rightChild;

        toSumTree(root);
        printInorder(root);
    }
}
// Structure of a tree node
class Node {
constructor(val)
{
this.data = val;
this.left = null;
this.right = null;
}
}
// Function to store subtree sums
function storeSum(root, map)
{
// Base case
if (root === null)
return 0;
// Recursively get sum of left and right subtree
let leftSum = storeSum(root.left, map);
let rightSum = storeSum(root.right, map);
// Total sum including current node
let total = root.data + leftSum + rightSum;
// Store subtree sum of current node in Map
map.set(root, total);
// Return subtree sum to parent
return total;
}
// Helper function to convert tree into Sum Tree
function solve(root, map)
{
// Base case
if (root === null)
return;
// Get left subtree sum from Map
let leftSum = root.left ? map.get(root.left) : 0;
// Get right subtree sum from Map
let rightSum = root.right ? map.get(root.right) : 0;
// Update current node value
root.data = leftSum + rightSum;
// Recur for left and right subtree
solve(root.left, map);
solve(root.right, map);
}
// Function to convert tree into Sum Tree
// Converts the tree in place so every node holds the sum of its original
// left and right subtrees; a null root is left untouched.
function toSumTree(root)
{
    if (root !== null) {
        const subtreeSums = new Map();
        storeSum(root, subtreeSums); // pass 1: record original subtree sums
        solve(root, subtreeSums);    // pass 2: overwrite node values
    }
}
// Function to print inorder traversal
function printInorder(root, res)
{
// Base case
if (root === null)
return res;
// Traverse left subtree
res = printInorder(root.left, res);
// Add current node value
res += root.data + " ";
// Traverse right subtree
res = printInorder(root.right, res);
return res;
}
// Driver code
// Constructing the tree: root 10; children -2 and 6;
// -2 has children 8 and -4; 6 has children 7 and 5.
let root = new Node(10);
root.left = new Node(-2);
root.right = new Node(6);
root.left.left = new Node(8);
root.left.right = new Node(-4);
root.right.left = new Node(7);
root.right.right = new Node(5);
// Convert binary tree into Sum Tree
toSumTree(root);
// Collect the inorder traversal. `res` is declared with `let`: a bare
// assignment to an undeclared name creates an implicit global and throws a
// ReferenceError in strict mode (ES modules, "use strict").
let res = printInorder(root, "");
console.log(res.trim());
Output
0 4 0 20 0 12 0
[Alternative Approach] Using Postorder Traversal - O(n) Time and O(h) Space, where h is the height of the tree (the depth of the recursion stack)
The idea is to use postorder traversal (Left → Right → Node) so that we first compute the sum of left and right subtrees before updating the current node. While doing this, we keep track of the original value to return the correct subtree sum upward.
- Traverse tree using postorder (left → right → root)
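- For each node, save its old value, then recursively convert the left and right subtrees (each call returns the sum of that subtree's original values), and set the node's value = left sum + right sum
- Return (new node value + old value) to the parent; this equals the sum of all original values in the node's subtree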
#include <iostream>
using namespace std;
struct Node
{
    int data;    // value stored in this node
    Node *left;  // left child, or nullptr if absent
    Node *right; // right child, or nullptr if absent

    // Creates a childless node holding `val`.
    // `explicit` blocks accidental implicit int -> Node conversions.
    explicit Node(int val) : data(val), left(nullptr), right(nullptr) {}
};
// Helper function to convert a given tree to sum tree
// Rewrites every node under `node` to hold the sum of its original left and
// right subtrees, and returns the sum of the subtree's ORIGINAL values
// (old node value + both original subtree sums; 0 for an empty subtree),
// which the parent needs for its own new value.
int solve(Node *node)
{
    if (node == nullptr)
        return 0;

    const int originalValue = node->data;
    const int leftTotal = solve(node->left);
    const int rightTotal = solve(node->right);

    node->data = leftTotal + rightTotal;
    return node->data + originalValue;
}
// Convert a given tree to a tree where every node contains the sum of the
// values in its left and right subtrees of the original tree. The overall
// original sum returned by solve() is not needed here and is discarded.
void toSumTree(Node *node)
{
    (void)solve(node);
}
// A function to print the inorder traversal of a binary tree:
// left subtree, node, right subtree. Each value is preceded by a space;
// no trailing newline is written.
void printInorder(Node* node)
{
    if (node == nullptr)
        return;
    printInorder(node->left);
    cout << ' ' << node->data;
    printInorder(node->right);
}
int main()
{
Node *root = NULL;
// Constructing tree given in the above figure
root = new Node(10);
root->left = new Node(-2);
root->right = new Node(6);
root->left->left = new Node(8);
root->left->right = new Node(-4);
root->right->left = new Node(7);
root->right->right = new Node(5);
toSumTree(root);
printInorder(root);
return 0;
}
#include <stdio.h>
#include <stdlib.h>
/* Binary tree node: an int payload plus left/right child links (NULL when absent). */
struct Node
{
    int data;
    struct Node *left, *right;
};
// Function to create new node
// Allocates a node holding `val` with no children; the caller owns the
// returned memory (release with free()). On allocation failure this prints
// a diagnostic and terminates the program instead of returning NULL, because
// every caller in this file dereferences the result immediately.
struct Node *newNode(int val)
{
    /* No cast needed in C; sizeof *node stays correct if the type changes */
    struct Node *node = malloc(sizeof *node);
    if (node == NULL)
    {
        perror("newNode: malloc");
        exit(EXIT_FAILURE);
    }
    node->data = val;
    node->left = NULL;
    node->right = NULL;
    return node;
}
// Helper function to convert a given tree to sum tree
// Rewrites every node under `node` to hold the sum of its original left and
// right subtrees, and returns the sum of the subtree's original values
// (0 for NULL) so the caller can compute its own new value.
int solve(struct Node *node)
{
    int original, leftTotal, rightTotal;

    if (!node)
        return 0;

    original = node->data;
    leftTotal = solve(node->left);
    rightTotal = solve(node->right);
    node->data = leftTotal + rightTotal;
    return node->data + original;
}
// Convert tree to sum tree
// The overall original sum returned by solve() is not needed and is discarded.
void toSumTree(struct Node *node)
{
    (void)solve(node);
}
// Inorder traversal
// Prints each value preceded by a space (left subtree, node, right subtree);
// no trailing newline is written.
void printInorder(struct Node *node)
{
    if (node != NULL)
    {
        printInorder(node->left);
        printf(" %d", node->data);
        printInorder(node->right);
    }
}
// Driver code
// Builds the sample tree (root 10; children -2 and 6; -2 has children 8 and
// -4; 6 has children 7 and 5), converts it and prints the inorder traversal.
// The nodes are not freed; the process exits right after printing.
int main()
{
    struct Node *root = newNode(10);
    struct Node *l = newNode(-2);
    struct Node *r = newNode(6);

    l->left = newNode(8);
    l->right = newNode(-4);
    r->left = newNode(7);
    r->right = newNode(5);
    root->left = l;
    root->right = r;

    toSumTree(root);
    printInorder(root);
    return 0;
}
class GFG {
    /** Binary tree node with an int payload and optional left/right children. */
    static class Node {
        int data;
        Node left, right;

        Node(int val)
        {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Rewrites the subtree rooted at {@code node} into a sum tree in place and
     * returns the sum of its ORIGINAL values (0 for {@code null}), which the
     * parent needs to compute its own new value.
     */
    static int solve(Node node)
    {
        if (node == null) {
            return 0;
        }
        int originalValue = node.data;
        int leftTotal = solve(node.left);
        int rightTotal = solve(node.right);
        node.data = leftTotal + rightTotal;
        return node.data + originalValue;
    }

    /** Converts the tree rooted at {@code node} into a sum tree; the total is discarded. */
    static void toSumTree(Node node)
    {
        solve(node);
    }

    /** Prints node values in inorder, each followed by a space, without a newline. */
    static void printInorder(Node node)
    {
        if (node != null) {
            printInorder(node.left);
            System.out.print(node.data + " ");
            printInorder(node.right);
        }
    }

    /** Builds the sample tree, converts it, and prints the resulting inorder traversal. */
    public static void main(String[] args)
    {
        Node leftChild = new Node(-2);
        leftChild.left = new Node(8);
        leftChild.right = new Node(-4);

        Node rightChild = new Node(6);
        rightChild.left = new Node(7);
        rightChild.right = new Node(5);

        Node root = new Node(10);
        root.left = leftChild;
        root.right = rightChild;

        toSumTree(root);
        printInorder(root);
    }
}
# A tree node structure
class Node:
    """Binary tree node storing ``data`` plus ``left``/``right`` children (None when absent)."""

    def __init__(self, val):
        self.data, self.left, self.right = val, None, None
# Helper function to convert a given tree to sum tree
def solve(node):
    """Turn the subtree under ``node`` into a sum tree, in place.

    Returns:
        The sum of the subtree's ORIGINAL values (0 for None), which the
        parent needs to compute its own new value.
    """
    if node is None:
        return 0
    original = node.data
    left_total = solve(node.left)
    node.data = left_total + solve(node.right)
    return node.data + original
# Convert tree to sum tree
def toSumTree(node):
    """Convert the tree rooted at ``node`` into a sum tree in place.

    The original total returned by ``solve`` is discarded; returns None.
    """
    solve(node)
# Inorder traversal
def printInorder(node):
    """Print values in inorder, each followed by a space, with no trailing newline."""
    if node is not None:
        printInorder(node.left)
        print(node.data, end=" ")
        printInorder(node.right)
# Driver code
if __name__ == "__main__":
    # Build the sample tree bottom-up: root 10 with children -2 (8, -4) and 6 (7, 5)
    left = Node(-2)
    left.left, left.right = Node(8), Node(-4)
    right = Node(6)
    right.left, right.right = Node(7), Node(5)
    root = Node(10)
    root.left, root.right = left, right
    toSumTree(root)
    printInorder(root)
using System;
// Binary tree node: integer payload plus left/right child references.
class Node {
    public int data;
    public Node left;
    public Node right;

    // Creates a childless node holding `val`.
    public Node(int val)
    {
        this.data = val;
        this.left = this.right = null;
    }
}
class Program {
    // Turns the subtree under `node` into a sum tree in place and returns the
    // sum of its ORIGINAL values (0 for null) for use by the parent.
    static int Solve(Node node)
    {
        if (node == null) {
            return 0;
        }
        int originalValue = node.data;
        int leftTotal = Solve(node.left);
        int rightTotal = Solve(node.right);
        node.data = leftTotal + rightTotal;
        return node.data + originalValue;
    }

    // Converts the tree rooted at `node` into a sum tree; the total is discarded.
    static void toSumTree(Node node)
    {
        Solve(node);
    }

    // Writes node values in inorder, each preceded by a space, with no newline.
    static void PrintInorder(Node node)
    {
        if (node != null) {
            PrintInorder(node.left);
            Console.Write(" " + node.data);
            PrintInorder(node.right);
        }
    }

    // Driver: builds the sample tree, converts it and prints the result.
    static void Main()
    {
        Node leftChild = new Node(-2);
        leftChild.left = new Node(8);
        leftChild.right = new Node(-4);

        Node rightChild = new Node(6);
        rightChild.left = new Node(7);
        rightChild.right = new Node(5);

        Node root = new Node(10);
        root.left = leftChild;
        root.right = rightChild;

        toSumTree(root);
        PrintInorder(root);
    }
}
class Node {
constructor(val)
{
this.data = val;
this.left = null;
this.right = null;
}
}
// Helper function to convert a given tree to Sum Tree
function solve(root)
{
// Base case
if (root === null)
return 0;
// Store the old value of current root
let oldVal = root.data;
// Recursively calculate sum of left and right subtree
// and update current root value
root.data = solve(root.left) + solve(root.right);
// Return:
// left subtree sum + right subtree sum + old root value
// This total sum is used by parent root
return root.data + oldVal;
}
// Function to convert binary tree into Sum Tree
// Converts the tree in place; the original total returned by solve() is unused.
function toSumTree(root)
{
    solve(root);
}
// Function to print inorder traversal
function printInorder(root, res)
{
// Base case
if (root === null)
return res;
// Traverse left subtree
res = printInorder(root.left, res);
// Print current root
res += root.data + " ";
// Traverse right subtree
res = printInorder(root.right, res);
return res;
}
// Driver Code
// Sample tree: root 10; children -2 and 6; -2 has children 8 and -4;
// 6 has children 7 and 5.
let root = new Node(10);
root.left = new Node(-2);
root.right = new Node(6);
root.left.left = new Node(8);
root.left.right = new Node(-4);
root.right.left = new Node(7);
root.right.right = new Node(5);
toSumTree(root);
// Inorder traversal of the sum tree; expected "0 4 0 20 0 12 0"
let res = printInorder(root, "");
console.log(res.trim());
Output
0 4 0 20 0 12 0

