Given the root of a Binary Search Tree (BST), transform it into a greater sum tree, in which each node's value is replaced by the sum of the values of all nodes greater than that node's original value.
Example:
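Input: root = [[11], [2, 29], [1, 7, 15, 40], [N, N, N, N, N, N, 35, N]]
Output: [[119], [137, 75], [139, 130, 104, 0], [N, N, N, N, N, N, 40, N]] Explanation: Every node contains the sum of all nodes greater than the current node's value. For example, the root 11 becomes 15 + 29 + 35 + 40 = 119, and the largest node 40 becomes 0.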
[Naive Approach] By Calculating Sum for Each Node - O(n^2) Time and O(n) Space
The idea is to traverse the tree, and for each node, calculate the sum of all nodes greater than the current node and store this sum for that node. Then, traverse the tree again and replace each node’s value with its corresponding sum.
C++
//Driver Code Starts
#include <algorithm>
#include <iostream>
#include <queue>
#include <unordered_map>
using namespace std;

// Node structure
class Node {
public:
    int data;
    Node* left;
    Node* right;

    Node(int value) {
        data = value;
        left = nullptr;
        right = nullptr;
    }
};

// Calculate Height
// Returns the depth of the deepest node below root, where h is the depth
// of root itself. An empty subtree yields h - 1, so getHeight(nullptr, 0)
// is -1.
int getHeight(Node* root, int h) {
    if (root == nullptr)
        return h - 1;

    return max(getHeight(root->left, h + 1),
               getHeight(root->right, h + 1));
}

// Print Level Order
// Prints one line per level; every value is followed by a space and a
// missing child of a printed node is shown as "N". Nothing is printed for
// an empty tree.
void levelOrder(Node* root) {

    // function to get the height of tree
    int height = getHeight(root, 0);

    // Missing children are queued as nullptr rather than as sentinel
    // nodes, so a node whose value is -1 is still printed correctly and
    // no placeholder nodes are allocated.
    queue<Node*> q;
    q.push(root);

    // printing the level order of tree
    for (int lvl = 0; lvl <= height; lvl++) {
        size_t count = q.size();

        for (size_t i = 0; i < count; i++) {
            Node* node = q.front();
            q.pop();

            // null node has no children
            if (node == nullptr) {
                cout << "N ";
                continue;
            }

            cout << node->data << " ";
            q.push(node->left);
            q.push(node->right);
        }

        // end of the current level
        cout << "\n";
    }
}
//Driver Code Ends

// Function to find nodes having greater value than current node.
// Walks the whole tree rooted at root and adds every value larger than
// curr->data to map[curr].
void findGreaterNodes(Node* root, Node* curr,
                      unordered_map<Node*, int>& map) {
    if (root == nullptr)
        return;

    // if value is greater than node, then increment
    // it in the map
    if (root->data > curr->data)
        map[curr] += root->data;

    findGreaterNodes(root->left, curr, map);
    findGreaterNodes(root->right, curr, map);
}

// Computes the greater sum of every node in the subtree rooted at curr.
// root is the root of the whole tree; results are stored in map.
void transformToGreaterSumTree(Node* curr, Node* root,
                               unordered_map<Node*, int>& map) {
    if (curr == nullptr) {
        return;
    }

    // Find all nodes greater than current node
    findGreaterNodes(root, curr, map);

    // Recursively check for left and right subtree.
    transformToGreaterSumTree(curr->left, root, map);
    transformToGreaterSumTree(curr->right, root, map);
}

// Function to update value of each node.
// A node without an entry in map receives 0 (operator[] default).
void preOrderTrav(Node* root, unordered_map<Node*, int>& map) {
    if (root == nullptr)
        return;

    root->data = map[root];

    preOrderTrav(root->left, map);
    preOrderTrav(root->right, map);
}

// Replaces every node's value with the sum of all values greater than it.
// The sums are computed first and written back afterwards, so the
// comparisons always see the original values.
void transformTree(Node* root) {

    // map to store greater sum for each node.
    unordered_map<Node*, int> map;
    transformToGreaterSumTree(root, root, map);

    // update the value of nodes
    preOrderTrav(root, map);
}

//Driver Code Starts
int main() {

    /* Constructing the BST
             11
            /  \
           2    29
          / \   / \
         1   7 15  40
                     \
                      50
       (block comment: a '//' line ending in a backslash would splice
       the following line into the comment) */
    Node* root = new Node(11);
    root->left = new Node(2);
    root->right = new Node(29);
    root->left->left = new Node(1);
    root->left->right = new Node(7);
    root->right->left = new Node(15);
    root->right->right = new Node(40);
    root->right->right->right = new Node(50);

    transformTree(root);
    levelOrder(root);

    return 0;
}
//Driver Code Ends
Java
//Driver Code Starts
import java.util.HashMap;
import java.util.List;
import java.util.Queue;
import java.util.LinkedList;

// Node structure
class Node {
    int data;
    Node left, right;

    Node(int value) {
        data = value;
        left = null;
        right = null;
    }
}

class GfG {

    // Calculate Height
    // Returns the depth of the deepest node below root, where h is the
    // depth of root itself. An empty subtree yields h - 1, so
    // getHeight(null, 0) is -1.
    static int getHeight(Node root, int h) {
        if (root == null) {
            return h - 1;
        }
        return Math.max(getHeight(root.left, h + 1),
                        getHeight(root.right, h + 1));
    }

    // Print Level Order
    // Prints one line per level; every value is followed by a space and a
    // missing child of a printed node is shown as "N". Nothing is printed
    // for an empty tree.
    static void levelOrder(Node root) {

        // function to get the height of tree
        int height = getHeight(root, 0);

        // Missing children are queued as null instead of sentinel nodes,
        // so a node whose value is -1 is still printed correctly.
        // LinkedList is used because it accepts null elements.
        Queue<Node> queue = new LinkedList<>();
        queue.offer(root);

        // printing the level order of tree
        for (int lvl = 0; lvl <= height; lvl++) {
            int count = queue.size();

            for (int i = 0; i < count; i++) {
                Node node = queue.poll();

                // null node has no children
                if (node == null) {
                    System.out.print("N ");
                    continue;
                }

                System.out.print(node.data + " ");
                queue.offer(node.left);
                queue.offer(node.right);
            }

            // end of the current level
            System.out.println();
        }
    }
//Driver Code Ends

    // Function to find nodes having greater value than
    // current node.
    // Walks the whole tree rooted at root and adds every value larger
    // than curr.data to the entry stored for curr in map.
    static void findGreaterNodes(Node root, Node curr,
                                 HashMap<Node, Integer> map) {
        if (root == null) {
            return;
        }

        // if value is greater than node, then increment
        // it in the map
        if (root.data > curr.data) {
            map.merge(curr, root.data, Integer::sum);
        }

        findGreaterNodes(root.left, curr, map);
        findGreaterNodes(root.right, curr, map);
    }

    // Computes the greater sum of every node in the subtree rooted at
    // curr. root is the root of the whole tree; results go into map.
    static void transformToGreaterSumTree(Node curr, Node root,
                                          HashMap<Node, Integer> map) {
        if (curr == null) {
            return;
        }

        // Find all nodes greater than current node
        findGreaterNodes(root, curr, map);

        // Recursively check for left and right subtree.
        transformToGreaterSumTree(curr.left, root, map);
        transformToGreaterSumTree(curr.right, root, map);
    }

    // Function to update value of each node.
    // A node without an entry in map (nothing greater) receives 0.
    static void preOrderTrav(Node root, HashMap<Node, Integer> map) {
        if (root == null) {
            return;
        }

        root.data = map.getOrDefault(root, 0);

        preOrderTrav(root.left, map);
        preOrderTrav(root.right, map);
    }

    // Replaces every node's value with the sum of all values greater than
    // it. The sums are computed first and written back afterwards, so the
    // comparisons always see the original values.
    static void transformTree(Node root) {

        // map to store greater sum for each node.
        HashMap<Node, Integer> map = new HashMap<>();
        transformToGreaterSumTree(root, root, map);

        // update the value of nodes
        preOrderTrav(root, map);
    }

//Driver Code Starts
    public static void main(String[] args) {

        // Constructing the BST
        //        11
        //       /  \
        //      2    29
        //     / \   / \
        //    1   7 15  40
        //                \
        //                 50
        Node root = new Node(11);
        root.left = new Node(2);
        root.right = new Node(29);
        root.left.left = new Node(1);
        root.left.right = new Node(7);
        root.right.left = new Node(15);
        root.right.right = new Node(40);
        root.right.right.right = new Node(50);

        transformTree(root);
        levelOrder(root);
    }
}
//Driver Code Ends
Python
#Driver Code Starts
from collections import deque


# Node structure
class Node:
    def __init__(self, value):
        self.data = value
        self.left = None
        self.right = None


# Calculate height
def getHeight(root, h):
    """Return the depth of the deepest node below root.

    h is the depth of root itself; an empty subtree yields h - 1, so
    getHeight(None, 0) is -1.
    """
    if root is None:
        return h - 1
    return max(getHeight(root.left, h + 1), getHeight(root.right, h + 1))


# Print Level Order
def levelOrder(root):
    """Print the tree one level per line.

    Every value is followed by a space and a missing child of a printed
    node is shown as "N". Nothing is printed for an empty tree.
    """
    # function to get the height of tree
    height = getHeight(root, 0)

    # Missing children are queued as None instead of sentinel nodes, so a
    # node whose value is -1 is still printed correctly.
    queue = deque([root])

    # printing the level order of tree
    for _ in range(height + 1):
        for _ in range(len(queue)):
            node = queue.popleft()

            # null node has no children
            if node is None:
                print("N", end=" ")
                continue

            print(node.data, end=" ")
            queue.append(node.left)
            queue.append(node.right)

        # end of the current level
        print()
#Driver Code Ends


# Function to find nodes having greater
# value than current node.
def findGreaterNodes(root, curr, map):
    """Add every value in the tree at root that exceeds curr.data to map[curr]."""
    if root is None:
        return

    # if value is greater than node, then increment
    # it in the map (the key does not have to exist beforehand)
    if root.data > curr.data:
        map[curr] = map.get(curr, 0) + root.data

    findGreaterNodes(root.left, curr, map)
    findGreaterNodes(root.right, curr, map)


def transformToGreaterSumTree(curr, root, map):
    """Store in map the greater sum of every node in the subtree at curr.

    root is the root of the whole tree being searched.
    """
    if curr is None:
        return

    # Find all nodes greater than current node
    map[curr] = 0
    findGreaterNodes(root, curr, map)

    # Recursively check for left and right subtree.
    transformToGreaterSumTree(curr.left, root, map)
    transformToGreaterSumTree(curr.right, root, map)


# Function to update value of each node.
def preOrderTrav(root, map):
    """Overwrite each node's data with its entry in map; nodes without an
    entry keep their value."""
    if root is None:
        return

    root.data = map.get(root, root.data)

    preOrderTrav(root.left, map)
    preOrderTrav(root.right, map)


def transformTree(root):
    """Replace every node's value with the sum of all values greater than it.

    The sums are computed first and written back afterwards, so the
    comparisons always see the original values.
    """
    # map to store greater sum for each node.
    map = {}
    transformToGreaterSumTree(root, root, map)

    # update the value of nodes
    preOrderTrav(root, map)


#Driver Code Starts
if __name__ == "__main__":

    # Constructing the BST
    #        11
    #       /  \
    #      2    29
    #     / \   / \
    #    1   7 15  40
    #                \
    #                 50
    root = Node(11)
    root.left = Node(2)
    root.right = Node(29)
    root.left.left = Node(1)
    root.left.right = Node(7)
    root.right.left = Node(15)
    root.right.right = Node(40)
    root.right.right.right = Node(50)

    transformTree(root)
    levelOrder(root)
#Driver Code Ends
C#
//Driver Code Starts
// C# program to transform a BST
// to sum tree
using System;
using System.Collections.Generic;

// Node structure
class Node {
    public int data;
    public Node left, right;

    public Node(int value) {
        data = value;
        left = null;
        right = null;
    }
}

class GfG {

    // Calculate height
    // Returns the depth of the deepest node below root, where h is the
    // depth of root itself. An empty subtree yields h - 1, so
    // getHeight(null, 0) is -1.
    static int getHeight(Node root, int h) {
        if (root == null)
            return h - 1;

        return Math.Max(getHeight(root.left, h + 1),
                        getHeight(root.right, h + 1));
    }

    // Print Level Order
    // Prints one line per level; every value is followed by a space and a
    // missing child of a printed node is shown as "N". Nothing is printed
    // for an empty tree.
    static void levelOrder(Node root) {

        // function to get the height of tree
        int height = getHeight(root, 0);

        // Missing children are queued as null instead of sentinel nodes,
        // so a node whose value is -1 is still printed correctly.
        Queue<Node> queue = new Queue<Node>();
        queue.Enqueue(root);

        // printing the level order of tree
        for (int lvl = 0; lvl <= height; lvl++) {
            int count = queue.Count;

            for (int i = 0; i < count; i++) {
                Node node = queue.Dequeue();

                // null node has no children
                if (node == null) {
                    Console.Write("N ");
                    continue;
                }

                Console.Write(node.data + " ");
                queue.Enqueue(node.left);
                queue.Enqueue(node.right);
            }

            // end of the current level
            Console.WriteLine();
        }
    }
//Driver Code Ends

    // Function to find nodes having greater value
    // than current node.
    // Walks the whole tree rooted at root and adds every value larger
    // than curr.data to the entry stored for curr in map.
    static void findGreaterNodes(Node root, Node curr,
                                 Dictionary<Node, int> map) {
        if (root == null)
            return;

        // if value is greater than node, then increment
        // it in the map; TryGetValue leaves soFar at 0 when the key has
        // not been seeded, instead of throwing
        if (root.data > curr.data) {
            map.TryGetValue(curr, out int soFar);
            map[curr] = soFar + root.data;
        }

        findGreaterNodes(root.left, curr, map);
        findGreaterNodes(root.right, curr, map);
    }

    // Computes the greater sum of every node in the subtree rooted at
    // curr. root is the root of the whole tree; results go into map.
    static void transformToGreaterSumTree(Node curr, Node root,
                                          Dictionary<Node, int> map) {
        if (curr == null) {
            return;
        }

        // Find all nodes greater than
        // current node
        map[curr] = 0;
        findGreaterNodes(root, curr, map);

        // Recursively check for left and right subtree.
        transformToGreaterSumTree(curr.left, root, map);
        transformToGreaterSumTree(curr.right, root, map);
    }

    // Function to update value of each node.
    // A node without an entry in map keeps its value.
    static void preOrderTrav(Node root, Dictionary<Node, int> map) {
        if (root == null)
            return;

        root.data = map.ContainsKey(root) ? map[root] : root.data;

        preOrderTrav(root.left, map);
        preOrderTrav(root.right, map);
    }

    // Replaces every node's value with the sum of all values greater than
    // it. The sums are computed first and written back afterwards, so the
    // comparisons always see the original values.
    static void transformTree(Node root) {

        // map to store greater sum for each node.
        Dictionary<Node, int> map = new Dictionary<Node, int>();
        transformToGreaterSumTree(root, root, map);

        // update the value of nodes
        preOrderTrav(root, map);
    }

//Driver Code Starts
    static void Main(string[] args) {

        // Constructing the BST
        //        11
        //       /  \
        //      2    29
        //     / \   / \
        //    1   7 15  40
        //                \
        //                 50
        Node root = new Node(11);
        root.left = new Node(2);
        root.right = new Node(29);
        root.left.left = new Node(1);
        root.left.right = new Node(7);
        root.right.left = new Node(15);
        root.right.right = new Node(40);
        root.right.right.right = new Node(50);

        transformTree(root);
        levelOrder(root);
    }
}
//Driver Code Ends
JavaScript
//Driver Code Starts
// JavaScript program to transform
// a BST to sum tree

// Node structure
class Node {
    constructor(value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

// Calculate Height
// Returns the depth of the deepest node below root, where h is the depth
// of root itself. An empty subtree yields h - 1, so getHeight(null, 0)
// is -1.
function getHeight(root, h) {
    if (root === null)
        return h - 1;

    return Math.max(getHeight(root.left, h + 1),
                    getHeight(root.right, h + 1));
}

// Print Level Order
// Prints one line per level; every value is followed by a space and a
// missing child of a printed node is shown as "N". Nothing is printed for
// an empty tree.
function levelOrder(root) {

    // function to get the height of tree
    const height = getHeight(root, 0);

    // Missing children are kept as null instead of sentinel nodes, so a
    // node whose value is -1 is still printed correctly. Each level is
    // held in its own array, which avoids repeated Array.shift() calls.
    let level = [root];

    // printing the level order of tree
    for (let lvl = 0; lvl <= height; lvl++) {
        const next = [];

        for (const node of level) {

            // null node has no children
            if (node === null) {
                process.stdout.write("N ");
                continue;
            }

            process.stdout.write(node.data + " ");
            next.push(node.left, node.right);
        }

        // end of the current level
        console.log("");
        level = next;
    }
}
//Driver Code Ends

// Function to find nodes having greater value
// than current node.
// Walks the whole tree rooted at root and adds every value larger than
// curr.data to the entry stored for curr in map.
function findGreaterNodes(root, curr, map) {
    if (root === null)
        return;

    // if value is greater than node, then increment
    // it in the map
    if (root.data > curr.data) {
        map.set(curr, (map.get(curr) || 0) + root.data);
    }

    findGreaterNodes(root.left, curr, map);
    findGreaterNodes(root.right, curr, map);
}

// Computes the greater sum of every node in the subtree rooted at curr.
// root is the root of the whole tree; results are stored in map.
function transformToGreaterSumTree(curr, root, map) {
    if (curr === null) {
        return;
    }

    // Find all nodes greater than current node
    findGreaterNodes(root, curr, map);

    // Recursively check for left and right subtree.
    transformToGreaterSumTree(curr.left, root, map);
    transformToGreaterSumTree(curr.right, root, map);
}

// Function to update value of each node.
// A node without an entry in map (nothing greater) receives 0.
function preOrderTrav(root, map) {
    if (root === null)
        return;

    root.data = map.has(root) ? map.get(root) : 0;

    preOrderTrav(root.left, map);
    preOrderTrav(root.right, map);
}

// Replaces every node's value with the sum of all values greater than it.
// The sums are computed first and written back afterwards, so the
// comparisons always see the original values.
function transformTree(root) {

    // map to store greater sum for each node.
    const map = new Map();
    transformToGreaterSumTree(root, root, map);

    // update the value of nodes
    preOrderTrav(root, map);
}

//Driver Code Starts
// Constructing the BST
//        11
//       /  \
//      2    29
//     / \   / \
//    1   7 15  40
//                \
//                 50
const root = new Node(11);
root.left = new Node(2);
root.right = new Node(29);
root.left.left = new Node(1);
root.left.right = new Node(7);
root.right.left = new Node(15);
root.right.right = new Node(40);
root.right.right.right = new Node(50);

transformTree(root);
levelOrder(root);
//Driver Code Ends
Output
134
152 90
154 145 119 50
N N N N N N N 0
[Expected Approach] Using Single Traversal - O(n) Time and O(n) Space
The idea is to optimize the above approach. Rather than calculating the sum for each node separately, we traverse the tree in reverse in-order (right → root → left) while keeping a running sum of all previously visited nodes. Each node’s value is updated to this running sum, ensuring it contains the sum of all nodes greater than itself.
C++
//Driver Code Starts
#include <algorithm>
#include <iostream>
#include <queue>
#include <unordered_map>
using namespace std;

// Node structure
class Node {
public:
    int data;
    Node* left;
    Node* right;

    Node(int value) {
        data = value;
        left = nullptr;
        right = nullptr;
    }
};

// Calculate height
// Returns the depth of the deepest node below root, where h is the depth
// of root itself. An empty subtree yields h - 1, so getHeight(nullptr, 0)
// is -1.
int getHeight(Node* root, int h) {
    if (root == nullptr)
        return h - 1;

    return max(getHeight(root->left, h + 1),
               getHeight(root->right, h + 1));
}

// Print Level order
// Prints one line per level; every value is followed by a space and a
// missing child of a printed node is shown as "N". Nothing is printed for
// an empty tree.
void levelOrder(Node* root) {

    // function to get the height of tree
    int height = getHeight(root, 0);

    // Missing children are queued as nullptr rather than as sentinel
    // nodes, so a node whose value is -1 is still printed correctly and
    // no placeholder nodes are allocated.
    queue<Node*> q;
    q.push(root);

    // printing the level order of tree
    for (int lvl = 0; lvl <= height; lvl++) {
        size_t count = q.size();

        for (size_t i = 0; i < count; i++) {
            Node* node = q.front();
            q.pop();

            // null node has no children
            if (node == nullptr) {
                cout << "N ";
                continue;
            }

            cout << node->data << " ";
            q.push(node->left);
            q.push(node->right);
        }

        // end of the current level
        cout << "\n";
    }
}
//Driver Code Ends

// Function to update the tree
// Visits the nodes in reverse in-order (right, node, left). sum carries
// the total of the original values of all nodes visited so far.
void updateTree(Node* root, int& sum) {
    if (root == nullptr) {
        return;
    }

    // Traverse the right subtree first (larger values)
    updateTree(root->right, sum);

    // Update the sum and the current node's value:
    // the node receives the running total excluding its own value
    sum += root->data;
    root->data = sum - root->data;

    // Traverse the left subtree (smaller values)
    updateTree(root->left, sum);
}

// Replaces every node's value with the sum of all values greater than it.
void transformTree(Node* root) {

    // Initialize the cumulative sum
    int sum = 0;
    updateTree(root, sum);
}

//Driver Code Starts
int main() {

    /* Constructing the BST
             11
            /  \
           2    29
          / \   / \
         1   7 15  40
                   /
                  35
       (block comment: a '//' line ending in a backslash would splice
       the following line into the comment) */
    Node* root = new Node(11);
    root->left = new Node(2);
    root->right = new Node(29);
    root->left->left = new Node(1);
    root->left->right = new Node(7);
    root->right->left = new Node(15);
    root->right->right = new Node(40);
    root->right->right->left = new Node(35);

    transformTree(root);
    levelOrder(root);

    return 0;
}
//Driver Code Ends
Java
//Driver Code Starts
import java.util.List;
import java.util.Queue;
import java.util.LinkedList;

// Node structure
class Node {
    int data;
    Node left, right;

    Node(int value) {
        data = value;
        left = right = null;
    }
}

class GfG {

    // Calculate height
    // Returns the depth of the deepest node below root, where h is the
    // depth of root itself. An empty subtree yields h - 1, so
    // getHeight(null, 0) is -1.
    static int getHeight(Node root, int h) {
        if (root == null) {
            return h - 1;
        }
        return Math.max(getHeight(root.left, h + 1),
                        getHeight(root.right, h + 1));
    }

    // Print level Order
    // Prints one line per level; every value is followed by a space and a
    // missing child of a printed node is shown as "N". Nothing is printed
    // for an empty tree.
    static void levelOrder(Node root) {

        // function to get the height of tree
        int height = getHeight(root, 0);

        // Missing children are queued as null instead of sentinel nodes,
        // so a node whose value is -1 is still printed correctly.
        // LinkedList is used because it accepts null elements.
        Queue<Node> queue = new LinkedList<>();
        queue.offer(root);

        // printing the level order of tree
        for (int lvl = 0; lvl <= height; lvl++) {
            int count = queue.size();

            for (int i = 0; i < count; i++) {
                Node node = queue.poll();

                // null node has no children
                if (node == null) {
                    System.out.print("N ");
                    continue;
                }

                System.out.print(node.data + " ");
                queue.offer(node.left);
                queue.offer(node.right);
            }

            // end of the current level
            System.out.println();
        }
    }
//Driver Code Ends

    // Function to update the tree
    // Visits the nodes in reverse in-order (right, node, left). sum[0]
    // carries the total of the original values of all nodes visited so
    // far; a one-element array is used so the callee can modify it.
    static void updateTree(Node root, int[] sum) {
        if (root == null) {
            return;
        }

        // Traverse the right subtree first (larger values)
        updateTree(root.right, sum);

        // Update the sum and the current node's value:
        // the node receives the running total excluding its own value
        sum[0] += root.data;
        root.data = sum[0] - root.data;

        // Traverse the left subtree (smaller values)
        updateTree(root.left, sum);
    }

    // Replaces every node's value with the sum of all values greater
    // than it.
    static void transformTree(Node root) {

        // Initialize the cumulative sum
        int[] sum = {0};
        updateTree(root, sum);
    }

//Driver Code Starts
    public static void main(String[] args) {

        // Constructing the BST
        //        11
        //       /  \
        //      2    29
        //     / \   / \
        //    1   7 15  40
        //              /
        //             35
        Node root = new Node(11);
        root.left = new Node(2);
        root.right = new Node(29);
        root.left.left = new Node(1);
        root.left.right = new Node(7);
        root.right.left = new Node(15);
        root.right.right = new Node(40);
        root.right.right.left = new Node(35);

        transformTree(root);
        levelOrder(root);
    }
}
//Driver Code Ends
Python
#Driver Code Starts
from collections import deque


# Node structure
class Node:
    def __init__(self, value):
        self.data = value
        self.left = None
        self.right = None


# Calculate Height
def getHeight(root, h):
    """Return the depth of the deepest node below root.

    h is the depth of root itself; an empty subtree yields h - 1, so
    getHeight(None, 0) is -1.
    """
    if root is None:
        return h - 1
    return max(getHeight(root.left, h + 1), getHeight(root.right, h + 1))


# Print Level Order
def levelOrder(root):
    """Print the tree one level per line.

    Every value is followed by a space and a missing child of a printed
    node is shown as "N". Nothing is printed for an empty tree.
    """
    # function to get the height of tree
    height = getHeight(root, 0)

    # Missing children are queued as None instead of sentinel nodes, so a
    # node whose value is -1 is still printed correctly.
    queue = deque([root])

    # printing the level order of tree
    for _ in range(height + 1):
        for _ in range(len(queue)):
            node = queue.popleft()

            # null node has no children
            if node is None:
                print("N", end=" ")
                continue

            print(node.data, end=" ")
            queue.append(node.left)
            queue.append(node.right)

        # end of the current level
        print()
#Driver Code Ends


# Function to update the tree
def updateTree(root, sum):
    """Visit the nodes in reverse in-order (right, node, left).

    sum is a one-element list; sum[0] carries the total of the original
    values of all nodes visited so far.
    """
    if root is None:
        return

    # Traverse the right subtree first (larger values)
    updateTree(root.right, sum)

    # Update the sum and the current node's value:
    # the node receives the running total excluding its own value
    sum[0] += root.data
    root.data = sum[0] - root.data

    # Traverse the left subtree (smaller values)
    updateTree(root.left, sum)


def transformTree(root):
    """Replace every node's value with the sum of all values greater than it."""
    # Initialize the cumulative sum
    sum = [0]
    updateTree(root, sum)


#Driver Code Starts
if __name__ == "__main__":

    # Constructing the BST
    #        11
    #       /  \
    #      2    29
    #     / \   / \
    #    1   7 15  40
    #              /
    #             35
    root = Node(11)
    root.left = Node(2)
    root.right = Node(29)
    root.left.left = Node(1)
    root.left.right = Node(7)
    root.right.left = Node(15)
    root.right.right = Node(40)
    root.right.right.left = Node(35)

    transformTree(root)
    levelOrder(root)
#Driver Code Ends
C#
//Driver Code Starts
using System;
using System.Collections.Generic;

// Node structure
class Node {
    public int data;
    public Node left, right;

    public Node(int value) {
        data = value;
        left = right = null;
    }
}

class GfG {

    // Calculate height
    // Returns the depth of the deepest node below root, where h is the
    // depth of root itself. An empty subtree yields h - 1, so
    // getHeight(null, 0) is -1.
    static int getHeight(Node root, int h) {
        if (root == null)
            return h - 1;

        return Math.Max(getHeight(root.left, h + 1),
                        getHeight(root.right, h + 1));
    }

    // Print Level Order
    // Prints one line per level; every value is followed by a space and a
    // missing child of a printed node is shown as "N". Nothing is printed
    // for an empty tree.
    static void levelOrder(Node root) {

        // function to get the height of tree
        int height = getHeight(root, 0);

        // Missing children are queued as null instead of sentinel nodes,
        // so a node whose value is -1 is still printed correctly.
        Queue<Node> queue = new Queue<Node>();
        queue.Enqueue(root);

        // printing the level order of tree
        for (int lvl = 0; lvl <= height; lvl++) {
            int count = queue.Count;

            for (int i = 0; i < count; i++) {
                Node node = queue.Dequeue();

                // null node has no children
                if (node == null) {
                    Console.Write("N ");
                    continue;
                }

                Console.Write(node.data + " ");
                queue.Enqueue(node.left);
                queue.Enqueue(node.right);
            }

            // end of the current level
            Console.WriteLine();
        }
    }
//Driver Code Ends

    // Function to update the tree
    // Visits the nodes in reverse in-order (right, node, left). sum
    // carries the total of the original values of all nodes visited so
    // far.
    static void updateTree(Node root, ref int sum) {
        if (root == null) {
            return;
        }

        // Traverse the right subtree first (larger values)
        updateTree(root.right, ref sum);

        // Update the sum and the current node's value:
        // the node receives the running total excluding its own value
        sum += root.data;
        root.data = sum - root.data;

        // Traverse the left subtree (smaller values)
        updateTree(root.left, ref sum);
    }

    // Replaces every node's value with the sum of all values greater
    // than it.
    static void transformTree(Node root) {

        // Initialize the cumulative sum
        int sum = 0;
        updateTree(root, ref sum);
    }

//Driver Code Starts
    static void Main() {

        // Constructing the BST
        //        11
        //       /  \
        //      2    29
        //     / \   / \
        //    1   7 15  40
        //              /
        //             35
        Node root = new Node(11);
        root.left = new Node(2);
        root.right = new Node(29);
        root.left.left = new Node(1);
        root.left.right = new Node(7);
        root.right.left = new Node(15);
        root.right.right = new Node(40);
        root.right.right.left = new Node(35);

        transformTree(root);
        levelOrder(root);
    }
}
//Driver Code Ends
JavaScript
//Driver Code Starts
// Node structure
class Node {
    constructor(value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

// Calculate Height
// Returns the depth of the deepest node below root, where h is the depth
// of root itself. An empty subtree yields h - 1, so getHeight(null, 0)
// is -1.
function getHeight(root, h) {
    if (root === null)
        return h - 1;

    return Math.max(getHeight(root.left, h + 1),
                    getHeight(root.right, h + 1));
}

// Print Level Order
// Prints one line per level; every value is followed by a space and a
// missing child of a printed node is shown as "N". Nothing is printed for
// an empty tree.
function levelOrder(root) {

    // function to get the height of tree
    const height = getHeight(root, 0);

    // Missing children are kept as null instead of sentinel nodes, so a
    // node whose value is -1 is still printed correctly. Each level is
    // held in its own array, which avoids repeated Array.shift() calls.
    let level = [root];

    // printing the level order of tree
    for (let lvl = 0; lvl <= height; lvl++) {
        const next = [];

        for (const node of level) {

            // null node has no children
            if (node === null) {
                process.stdout.write("N ");
                continue;
            }

            process.stdout.write(node.data + " ");
            next.push(node.left, node.right);
        }

        // end of the current level
        console.log("");
        level = next;
    }
}
//Driver Code Ends

// Function to update the tree
// Visits the nodes in reverse in-order (right, node, left). sum is a
// one-element array; sum[0] carries the total of the original values of
// all nodes visited so far.
function updateTree(root, sum) {
    if (root === null) {
        return;
    }

    // Traverse the right subtree first (larger values)
    updateTree(root.right, sum);

    // Update the sum and the current node's value:
    // the node receives the running total excluding its own value
    sum[0] += root.data;
    root.data = sum[0] - root.data;

    // Traverse the left subtree (smaller values)
    updateTree(root.left, sum);
}

// Replaces every node's value with the sum of all values greater than it.
function transformTree(root) {

    // Initialize the cumulative sum
    let sum = [0];
    updateTree(root, sum);
}

//Driver Code Starts
// Constructing the BST
//        11
//       /  \
//      2    29
//     / \   / \
//    1   7 15  40
//              /
//             35
let root = new Node(11);
root.left = new Node(2);
root.right = new Node(29);
root.left.left = new Node(1);
root.left.right = new Node(7);
root.right.left = new Node(15);
root.right.right = new Node(40);
root.right.right.left = new Node(35);

transformTree(root);
levelOrder(root);
//Driver Code Ends