Given a binary tree, a target node in it and a positive integer K, the task is to find the sum of all nodes within distance K of the target node (including the value of the target node itself).
Examples:
Input: target = 9, K = 1,
Binary Tree =
            1
          /   \
         2     9
        /     / \
       4     5   7
      / \       / \
     8   19    20  11
    /   /  \
   30  40  50
Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22.

Input: target = 40, K = 2,
Binary Tree =
            1
          /   \
         2     9
        /     / \
       4     5   7
      / \       / \
     8   19    20  11
    /   /  \
   30  40  50
Output: 113
Explanation: The nodes within distance 2 of 40 are 40, 19, 50 and 4, so the sum is
40 + 19 + 50 + 4 = 113.
Approach: This problem can be solved using hashing and Depth-First-Search based on the following idea:
Use a data structure to store the parent of each node. Now utilise that data structure to perform a DFS traversal from target and calculate the sum of all the nodes within K distance from that node.
Follow the steps mentioned below to implement the approach:
- Create a hash table (say par) to store the parent of each node.
- Perform a DFS and store the parent of each node.
- Now find the target in the tree.
- Create a hash table to mark the visited nodes.
- Start a DFS from target:
- If the distance from the target does not exceed K, add the node's value to the final sum; once the distance exceeds K, stop exploring along that path.
- If the node is not visited then continue the DFS traversal for its neighbours also (i.e. parent and child) with the help of par and the links of each node.
- When the recursion for the current node completes, the running sum includes every neighbour reachable from it within the remaining distance.
- Return the sum of all the nodes within K distance from the target.
Below is the implementation of the above approach:
// C++ code to implement above approach
#include <bits/stdc++.h>
using namespace std;
// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
        : data(val), left(nullptr), right(nullptr)
    {
    }
};

// Function for marking the parent node for all the nodes using DFS.
// For every non-null child c in the subtree at root, stores par[c] = parent.
// The caller seeds par[root] itself, since the root has no parent.
void dfs(Node* root,
         unordered_map<Node*, Node*>& par)
{
    if (root == nullptr)
        return;
    if (root->left != nullptr)
        par[root->left] = root;
    if (root->right != nullptr)
        par[root->right] = root;
    dfs(root->left, par);
    dfs(root->right, par);
}

// Function for finding the sum.
// Adds to `sum` the value of every node not yet in `vis` that lies within
// k - h edges of `root`, moving to the left child, right child and parent.
// h is root's distance from the start node; recursion stops at k + 1.
void dfs3(Node* root, int h, int& sum, int k,
          unordered_map<Node*, int>& vis,
          unordered_map<Node*, Node*>& par)
{
    if (h == k + 1)
        return;
    if (root == nullptr)
        return;
    // count() rather than operator[] so a lookup never inserts an entry
    if (vis.count(root))
        return;
    sum += root->data;
    vis[root] = 1;
    dfs3(root->left, h + 1, sum, k, vis, par);
    dfs3(root->right, h + 1, sum, k, vis, par);
    // find() rather than operator[] so par is not modified by lookups
    auto it = par.find(root);
    dfs3(it == par.end() ? nullptr : it->second, h + 1, sum, k, vis, par);
}

// Function for finding the target node in the tree.
// Returns the first node in pre-order whose data equals target,
// or nullptr if no such node exists.
Node* dfs2(Node* root, int target)
{
    if (root == nullptr)
        return nullptr;
    if (root->data == target)
        return root;
    Node* found = dfs2(root->left, target);
    if (found != nullptr)
        return found;
    // Explicit return: the previous version fell off the end of this
    // non-void function when the target was absent (undefined behaviour).
    return dfs2(root->right, target);
}

// Function to find the sum at distance K.
// Returns the sum of all node values within k edges of the node holding
// target (that node included), or 0 if target is not in the tree.
int sum_at_distK(Node* root, int target,
                 int k)
{
    // Hash table storing the parent of each node; the root has none.
    unordered_map<Node*, Node*> par;
    par[root] = nullptr;
    dfs(root, par);

    // Locate the start node (nullptr if absent, in which case dfs3 is a no-op).
    Node* node = dfs2(root, target);

    // Hash table marking nodes that have already been added to the sum.
    unordered_map<Node*, int> vis;
    int sum = 0;
    dfs3(node, 0, sum, k, vis, par);
    return sum;
}
// Driver Code
int main()
{
// Taking Input
Node* root = new Node(1);
root->left = new Node(2);
root->right = new Node(9);
root->left->left = new Node(4);
root->right->left = new Node(5);
root->right->right = new Node(7);
root->left->left->left = new Node(8);
root->left->left->right = new Node(19);
root->right->right->left = new Node(20);
root->right->right->right
= new Node(11);
root->left->left->left->left
= new Node(30);
root->left->left->right->left
= new Node(40);
root->left->left->right->right
= new Node(50);
int target = 9, K = 1;
// Function call
cout << sum_at_distK(root, target, K);
return 0;
}
// Java code to implement above approach
import java.util.*;
public class Main {
    /** A binary-tree node holding an int value and two child links. */
    static class Node {
        int data;
        Node left;
        Node right;

        Node(int val) {
            data = val;
            left = right = null;
        }
    }

    /**
     * Registers in {@code par} the parent of every child node found in the subtree at
     * {@code root}. The root's own entry is left to the caller.
     */
    static void dfs(Node root, HashMap<Node, Node> par) {
        if (root == null) {
            return;
        }
        for (Node child : new Node[] {root.left, root.right}) {
            if (child != null) {
                par.put(child, root);
            }
        }
        dfs(root.left, par);
        dfs(root.right, par);
    }

    /** Running total written by {@link #dfs3} and returned by {@link #sum_at_distK}. */
    static int sum;

    /**
     * Adds to {@link #sum} the value of each unvisited node reachable from {@code root} in at
     * most {@code k - h} further hops, moving through both children and the parent link.
     *
     * @param root current node (may be null)
     * @param h    distance of {@code root} from the start node
     * @param k    maximum allowed distance
     * @param vis  nodes already counted
     * @param par  child-to-parent links
     */
    static void dfs3(Node root, int h, int k,
                     HashMap<Node, Integer> vis,
                     HashMap<Node, Node> par) {
        boolean beyondRange = h == k + 1;
        if (beyondRange || root == null || vis.containsKey(root)) {
            return;
        }
        sum += root.data;
        vis.put(root, 1);
        int next = h + 1;
        dfs3(root.left, next, k, vis, par);
        dfs3(root.right, next, k, vis, par);
        dfs3(par.get(root), next, k, vis, par);
    }

    /** Returns the first node in pre-order whose value equals {@code target}, or null. */
    static Node dfs2(Node root, int target) {
        if (root == null || root.data == target) {
            return root;
        }
        Node inLeft = dfs2(root.left, target);
        Node inRight = dfs2(root.right, target);
        return inLeft != null ? inLeft : inRight;
    }

    /**
     * Returns the sum of all node values within {@code k} edges of the node holding
     * {@code target} (that node included), or 0 when no node holds {@code target}.
     */
    static int sum_at_distK(Node root, int target,
                            int k) {
        HashMap<Node, Node> par = new HashMap<>();
        par.put(root, null); // the root has no parent
        dfs(root, par);
        Node start = dfs2(root, target);
        sum = 0;
        dfs3(start, 0, k, new HashMap<>(), par);
        return sum;
    }

    /** Builds the sample tree and prints the sum for target 9, K = 1. */
    public static void main(String[] args) {
        Node root = new Node(1);
        Node two = new Node(2);
        Node nine = new Node(9);
        root.left = two;
        root.right = nine;

        Node four = new Node(4);
        Node seven = new Node(7);
        two.left = four;
        nine.left = new Node(5);
        nine.right = seven;

        Node eight = new Node(8);
        Node nineteen = new Node(19);
        four.left = eight;
        four.right = nineteen;
        seven.left = new Node(20);
        seven.right = new Node(11);

        eight.left = new Node(30);
        nineteen.left = new Node(40);
        nineteen.right = new Node(50);

        int target = 9, K = 1;
        System.out.println( sum_at_distK(root, target, K) );
    }
}
// This code has been contributed by Sachin Sahara (sachin801)
# python program to implement above approach
# structure of tree node
class Node:
    """Binary-tree node: a value plus optional left and right children."""

    def __init__(self, val):
        self.data = val
        self.left = self.right = None
# function for marking the parent node
# for all the nodes using DFS
def dfs(root, par):
    """Record in ``par`` the parent of every child node in the subtree at ``root``.

    The caller seeds ``par[root]`` itself, because the root has no parent.
    """
    if root is None:
        return
    if root.left is not None:
        par[root.left] = root
    if root.right is not None:
        par[root.right] = root
    dfs(root.left, par)
    dfs(root.right, par)


# running total filled in by dfs3 and returned by sum_at_distK
summ = 0


def dfs3(root, h, k, vis, par):
    """Add to ``summ`` every unvisited node within ``k - h`` hops of ``root``.

    Moves to the left child, right child and parent (looked up in ``par``).
    ``vis`` marks nodes already counted, so no node is added twice; ``h`` is
    the distance of ``root`` from the start node.
    """
    global summ
    if h == k + 1 or root is None or vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    dfs3(par.get(root), h + 1, k, vis, par)


# function for finding
# the target node in the tree
def dfs2(root, target):
    """Return the first node in pre-order whose ``data`` equals ``target``, or None."""
    if root is None:
        return None
    if root.data == target:
        return root
    found = dfs2(root.left, target)
    if found is not None:
        return found
    return dfs2(root.right, target)


# function to find the sum at distance k
def sum_at_distK(root, target, k):
    """Return the sum of all nodes within ``k`` edges of the node holding ``target``.

    The target node itself is included; 0 is returned when no node holds
    ``target``. The result is also left in the module-level ``summ`` so
    callers that read ``summ`` directly keep working.
    """
    global summ
    # The root's parent must be None, not 0: dfs3 only stops on None, and a
    # 0 parent made it evaluate ``0.data`` whenever the walk reached the root.
    par = {root: None}
    dfs(root, par)
    node = dfs2(root, target)
    # Reset so repeated calls do not accumulate into the previous total.
    summ = 0
    dfs3(node, 0, k, {}, par)
    return summ
# driver program: build the sample tree and print the sum for target 9, K = 1
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)
target, K = 9, 1
# function call; the result is read back from the module-level total
sum_at_distK(root, target, K)
print(summ)
# this code is contributed by Yash Agarwal(yashagarwal2852002)
// C# code to implement above approach
using System;
using System.Collections.Generic;
public class GFG {
    // A binary-tree node: an int value plus left and right child links.
    class Node {
        public int data;
        public Node left;
        public Node right;
        public Node(int val)
        {
            data = val;
            left = right = null;
        }
    }

    // Registers in par the parent of every child node below root.
    // Dictionary.Add throws if a node is registered twice.
    static void dfs(Node root, Dictionary<Node, Node> par)
    {
        if (root == null)
            return;
        foreach (Node child in new[] { root.left, root.right }) {
            if (child != null)
                par.Add(child, root);
        }
        dfs(root.left, par);
        dfs(root.right, par);
    }

    // Running total written by dfs3 and returned by sum_at_distK.
    static int sum;

    // Adds to sum every unvisited node reachable from root in at most
    // k - h further hops, moving through both children and the parent link.
    static void dfs3(Node root, int h, int k,
                     Dictionary<Node, int> vis,
                     Dictionary<Node, Node> par)
    {
        bool stop = h == k + 1 || root == null || vis.ContainsKey(root);
        if (stop)
            return;
        sum += root.data;
        vis.Add(root, 1);
        int next = h + 1;
        dfs3(root.left, next, k, vis, par);
        dfs3(root.right, next, k, vis, par);
        dfs3(par[root], next, k, vis, par);
    }

    // Returns the first node in pre-order whose value equals target, or null.
    static Node dfs2(Node root, int target)
    {
        if (root == null || root.data == target)
            return root;
        Node fromLeft = dfs2(root.left, target);
        Node fromRight = dfs2(root.right, target);
        return fromLeft ?? fromRight;
    }

    // Returns the sum of all node values within k edges of the node holding
    // target (that node included), or 0 when target is not present.
    static int sum_at_distK(Node root, int target, int k)
    {
        var par = new Dictionary<Node, Node>();
        par.Add(root, null); // the root has no parent
        dfs(root, par);
        Node start = dfs2(root, target);
        sum = 0;
        dfs3(start, 0, k, new Dictionary<Node, int>(), par);
        return sum;
    }

    // Builds the sample tree and prints the sum for target 9, K = 1.
    static public void Main()
    {
        var root = new Node(1);
        var two = new Node(2);
        var nine = new Node(9);
        root.left = two;
        root.right = nine;

        var four = new Node(4);
        var seven = new Node(7);
        two.left = four;
        nine.left = new Node(5);
        nine.right = seven;

        var eight = new Node(8);
        var nineteen = new Node(19);
        four.left = eight;
        four.right = nineteen;
        seven.left = new Node(20);
        seven.right = new Node(11);

        eight.left = new Node(30);
        nineteen.left = new Node(40);
        nineteen.right = new Node(50);

        int target = 9, K = 1;
        Console.Write(sum_at_distK(root, target, K));
    }
}
// This code is contributed by lokesh(lokeshmvs21).
// JavaScript code for the above approach
// Structure of a tree node: a value plus left and right child links.
class Node {
  constructor(val) {
    Object.assign(this, { data: val, left: null, right: null });
  }
}
// Function for marking the parent node
// for all the nodes using DFS.
// Records in `par` the parent of every child node in the subtree at `root`;
// the caller seeds the root's own entry.
function dfs(root, par) {
  if (root === null) return;
  if (root.left !== null) par.set(root.left, root);
  if (root.right !== null) par.set(root.right, root);
  dfs(root.left, par);
  dfs(root.right, par);
}

// Running total filled in by dfs3 and returned by sumAtDistK.
let sum = 0;

/**
 * Adds to `sum` the value of every unvisited node within `k - h` hops of
 * `root`, moving to the left child, right child and parent.
 *
 * @param {Node|null} root current node
 * @param {number} h distance of `root` from the start node
 * @param {number} k maximum allowed distance
 * @param {Map} vis nodes already counted
 * @param {Map} par child -> parent links
 */
function dfs3(root, h, k, vis, par) {
  if (h === k + 1) return;
  if (root === null) return;
  if (vis.has(root)) return;
  sum += root.data;
  vis.set(root, 1);
  dfs3(root.left, h + 1, k, vis, par);
  dfs3(root.right, h + 1, k, vis, par);
  // Walk upward too. The old guard only recursed when the parent was
  // *already* visited (an immediate no-op), so parents were never counted.
  const parent = par.get(root);
  if (parent != null) dfs3(parent, h + 1, k, vis, par);
}

// Function for finding the target node in the tree.
// Returns the first node in pre-order whose data equals target, or null.
function dfs2(root, target) {
  if (root === null) return null;
  if (root.data === target) return root;
  let node1 = dfs2(root.left, target);
  let node2 = dfs2(root.right, target);
  if (node1 !== null) return node1;
  if (node2 !== null) return node2;
  return null;
}

/**
 * Returns the sum of all node values within `k` edges of the node holding
 * `target` (that node included), or 0 when no node holds `target`.
 */
function sumAtDistK(root, target, k)
{
  // Map to store the parent of a node; the root has none.
  let par = new Map();
  par.set(root, null);
  dfs(root, par);
  // Find the target node in the tree.
  let node = dfs2(root, target);
  // Map to mark the visited nodes.
  let vis = new Map();
  // Start from 0: the old initial value of 1 added a phantom unit that
  // happened to mask the missing parent on the sample input.
  sum = 0;
  dfs3(node, 0, k, vis, par);
  return sum;
}
// Taking Input: build the sample tree, then print the sum for target 9, K = 1.
const root = new Node(1);
const two = new Node(2);
const nine = new Node(9);
root.left = two;
root.right = nine;
const four = new Node(4);
two.left = four;
nine.left = new Node(5);
nine.right = new Node(7);
four.left = new Node(8);
four.right = new Node(19);
nine.right.left = new Node(20);
nine.right.right = new Node(11);
four.left.left = new Node(30);
four.right.left = new Node(40);
four.right.right = new Node(50);
const target = 9;
const K = 1;
console.log(sumAtDistK(root, target, K));
// This code is contributed by Potta Lokesh
Output
22
Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)
Approach using BFS:-
- We will be using level order traversal to find the sum of nodes
Implementation:-
- First we will find the target node using level order traversal.
- While finding the target node we will store the parent of each node so that we can move towards the parent of the node as well.
- After this, we traverse from the target node in every direction of the tree (towards both children and the parent) up to distance K, adding the value of each node we reach to the answer.
// C++ code to implement above approach
#include <bits/stdc++.h>
using namespace std;
// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
    {
        this->data = val;
        this->left = nullptr;
        this->right = nullptr;
    }
};

// Function to find the sum at distance K.
// Returns the sum of the values of every node within k edges of the node
// holding `target` (the target itself included). If several nodes hold
// `target`, the last one reached in level order is used, as before.
// Returns 0 when root is null, target is absent, or k is negative.
int sum_at_distK(Node* root, int target,
                 int k)
{
    if (root == nullptr)
        return 0;

    // Pass 1: level-order walk that records each node's parent and
    // remembers the node holding target.
    unordered_map<Node*, Node*> parent;
    // Initialised: it used to be read uninitialised when target was missing.
    Node* need = nullptr;
    queue<Node*> q;
    q.push(root);
    while (!q.empty()) {
        Node* cur = q.front();
        q.pop();
        if (cur->data == target)
            need = cur;
        if (cur->left) {
            parent[cur->left] = cur;
            q.push(cur->left);
        }
        if (cur->right) {
            parent[cur->right] = cur;
            q.push(cur->right);
        }
    }
    if (need == nullptr)
        return 0;

    // Pass 2: BFS outward from target through children and the parent link,
    // one ring per distance. Nodes are marked when enqueued, so none is
    // added twice.
    unordered_set<Node*> seen{need};
    q.push(need);
    int ans = 0;
    for (int dist = 0; dist <= k && !q.empty(); ++dist) {
        for (size_t n = q.size(); n > 0; --n) {
            Node* cur = q.front();
            q.pop();
            ans += cur->data;
            // moving to children and to the parent (find() avoids inserting)
            auto it = parent.find(cur);
            Node* up = (it == parent.end()) ? nullptr : it->second;
            for (Node* next : {cur->left, cur->right, up}) {
                if (next != nullptr && seen.insert(next).second)
                    q.push(next);
            }
        }
    }
    return ans;
}
// Driver Code
int main()
{
// Taking Input
Node* root = new Node(1);
root->left = new Node(2);
root->right = new Node(9);
root->left->left = new Node(4);
root->right->left = new Node(5);
root->right->right = new Node(7);
root->left->left->left = new Node(8);
root->left->left->right = new Node(19);
root->right->right->left = new Node(20);
root->right->right->right
= new Node(11);
root->left->left->left->left
= new Node(30);
root->left->left->right->left
= new Node(40);
root->left->left->right->right
= new Node(50);
int target = 9, K = 1;
// Function call
cout << sum_at_distK(root, target, K);
return 0;
}
//code contributed by shubhamrajput6156
import java.util.*;
// Structure of a tree node
class Node {
    int data;
    Node left;
    Node right;

    public Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {
    /**
     * Returns the sum of the values of all nodes within {@code k} edges of the node holding
     * {@code target}, the target itself included.
     *
     * <p>If several nodes hold {@code target}, the last one reached in level order is used.
     * Returns 0 when {@code root} is null, no node holds {@code target}, or {@code k} is
     * negative.
     *
     * @param root   root of the tree, may be null
     * @param target value identifying the start node
     * @param k      maximum distance, in edges, from the start node
     * @return the sum of the qualifying node values
     */
    public static int sumAtDistK(Node root, int target, int k) {
        if (root == null) {
            return 0;
        }
        Map<Node, Node> parentMap = new HashMap<>();
        Node need = findTargetAndParents(root, target, parentMap);
        if (need == null) {
            // Previously null was enqueued here and then dereferenced (NPE).
            return 0;
        }
        return sumWithin(need, k, parentMap);
    }

    /** Level-order walk that fills {@code parentMap} and returns the last node holding target. */
    private static Node findTargetAndParents(Node root, int target, Map<Node, Node> parentMap) {
        Deque<Node> queue = new ArrayDeque<>();
        queue.add(root);
        Node need = null;
        while (!queue.isEmpty()) {
            Node temp = queue.poll();
            if (temp.data == target) {
                need = temp;
            }
            if (temp.left != null) {
                parentMap.put(temp.left, temp);
                queue.add(temp.left);
            }
            if (temp.right != null) {
                parentMap.put(temp.right, temp);
                queue.add(temp.right);
            }
        }
        return need;
    }

    /** BFS outward from {@code start} through children and parents, summing rings 0..k. */
    private static int sumWithin(Node start, int k, Map<Node, Node> parentMap) {
        Deque<Node> queue = new ArrayDeque<>();
        Set<Node> visited = new HashSet<>();
        queue.add(start);
        visited.add(start);
        int ans = 0;
        for (int dist = 0; dist <= k && !queue.isEmpty(); dist++) {
            for (int remaining = queue.size(); remaining > 0; remaining--) {
                Node temp = queue.poll();
                ans += temp.data;
                for (Node next : new Node[] {temp.left, temp.right, parentMap.get(temp)}) {
                    // Set.add returns false for nodes already seen, so each is enqueued once.
                    if (next != null && visited.add(next)) {
                        queue.add(next);
                    }
                }
            }
        }
        return ans;
    }

    // Driver code
    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);
        int target = 9, K = 1;
        // Function call
        System.out.println(sumAtDistK(root, target, K));
    }
}
from collections import deque
class Node:
    """Binary-tree node: a value plus optional left and right children."""

    def __init__(self, val):
        self.data = val
        self.left = self.right = None
# Function to find the sum at distance K
def sum_at_distK(root, target, k):
    """Return the sum of node values within ``k`` edges of the node holding ``target``.

    The target node itself is included. If several nodes hold ``target``, the
    last one reached in level order is used. Returns 0 when ``root`` is None,
    no node holds ``target``, or ``k`` is negative.
    """
    if root is None:
        return 0

    # Pass 1: level-order walk recording each node's parent and locating target
    parent = {}
    need = None
    q = deque([root])
    while q:
        node = q.popleft()
        if node.data == target:
            need = node
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                q.append(child)

    # Previously None was queued here and ``None.data`` raised AttributeError
    if need is None:
        return 0

    # Pass 2: BFS outward from target, one ring per distance, through the
    # children and the parent; nodes are marked when enqueued
    seen = {need}
    q.append(need)
    ans = 0
    dist = 0
    while q and dist <= k:
        for _ in range(len(q)):
            node = q.popleft()
            ans += node.data
            for nb in (node.left, node.right, parent.get(node)):
                if nb is not None and nb not in seen:
                    seen.add(nb)
                    q.append(nb)
        dist += 1
    return ans
# Driver Code: build the sample tree and print the sum for target 9, K = 1
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)
target, K = 9, 1
result = sum_at_distK(root, target, K)
print(result)
using System;
using System.Collections.Generic;
// Structure of a tree node
class Node
{
    public int data;
    public Node left;
    public Node right;
    public Node(int val)
    {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

class GFG
{
    // Returns the sum of the values of all nodes within k edges of the node
    // holding target (the target itself included). If several nodes hold
    // target, the last one reached in level order is used. Returns 0 when
    // root is null, target is absent, or k is negative.
    public static int SumAtDistK(Node root, int target, int k)
    {
        if (root == null)
            return 0;

        // Pass 1: level-order walk recording parents and locating the target.
        var parentMap = new Dictionary<Node, Node>();
        Node need = null;
        var q = new Queue<Node>();
        q.Enqueue(root);
        while (q.Count > 0)
        {
            Node temp = q.Dequeue();
            if (temp.data == target)
                need = temp;
            foreach (Node child in new[] { temp.left, temp.right })
            {
                if (child != null)
                {
                    parentMap[child] = temp;
                    q.Enqueue(child);
                }
            }
        }

        // Previously null was enqueued here and then dereferenced.
        if (need == null)
            return 0;

        // Pass 2: BFS outward from the target, one ring per distance.
        // HashSet + TryGetValue replace Dictionary.GetValueOrDefault, which
        // is not available on .NET Framework.
        var visited = new HashSet<Node> { need };
        q.Enqueue(need);
        int ans = 0;
        for (int dist = 0; dist <= k && q.Count > 0; dist++)
        {
            for (int remaining = q.Count; remaining > 0; remaining--)
            {
                Node temp = q.Dequeue();
                ans += temp.data;
                Node up;
                parentMap.TryGetValue(temp, out up); // null for the root
                foreach (Node next in new[] { temp.left, temp.right, up })
                {
                    // HashSet.Add returns false for nodes already seen.
                    if (next != null && visited.Add(next))
                        q.Enqueue(next);
                }
            }
        }
        return ans;
    }

    // Driver code
    public static void Main(string[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);
        int target = 9, K = 1;
        // Function call
        Console.WriteLine(SumAtDistK(root, target, K));
    }
}
// A binary-tree node: a value plus left and right child links.
class Node {
  constructor(val) {
    Object.assign(this, { data: val, left: null, right: null });
  }
}
// Function to find the sum at distance K
/**
 * Returns the sum of the values of every node within `k` edges of the node
 * holding `target` (the target itself included).
 *
 * If several nodes hold `target`, the last one reached in level order is
 * used. Returns 0 when `root` is null/undefined, `target` is absent, or `k`
 * is negative.
 *
 * @param {Node|null} root root of the tree
 * @param {number} target value identifying the start node
 * @param {number} k maximum distance in edges
 * @returns {number}
 */
function sum_at_distK(root, target, k) {
  if (root == null) return 0;

  // Pass 1: level-order walk recording parents and locating the target.
  // A read index replaces Array#shift, which is O(n) per call and made the
  // traversal quadratic on large trees.
  const parent = new Map();
  const order = [root];
  let need = null;
  for (let head = 0; head < order.length; head++) {
    const temp = order[head];
    if (temp.data === target) need = temp;
    for (const child of [temp.left, temp.right]) {
      if (child) {
        parent.set(child, temp);
        order.push(child);
      }
    }
  }

  // Previously null was queued here and then dereferenced (TypeError).
  if (need === null) return 0;

  // Pass 2: expand one ring per distance through children and the parent;
  // nodes are marked when queued so none is counted twice.
  const seen = new Set([need]);
  let frontier = [need];
  let ans = 0;
  for (let dist = 0; dist <= k && frontier.length > 0; dist++) {
    const nextFrontier = [];
    for (const temp of frontier) {
      ans += temp.data;
      for (const next of [temp.left, temp.right, parent.get(temp)]) {
        if (next && !seen.has(next)) {
          seen.add(next);
          nextFrontier.push(next);
        }
      }
    }
    frontier = nextFrontier;
  }
  return ans;
}
// Driver Code: build the sample tree, then print the sum for target 9, K = 1.
const root = new Node(1);
const two = new Node(2);
const nine = new Node(9);
const four = new Node(4);
const seven = new Node(7);
root.left = two;
root.right = nine;
two.left = four;
nine.left = new Node(5);
nine.right = seven;
four.left = new Node(8);
four.right = new Node(19);
seven.left = new Node(20);
seven.right = new Node(11);
four.left.left = new Node(30);
four.right.left = new Node(40);
four.right.right = new Node(50);
const target = 9, K = 1;
console.log(sum_at_distK(root, target, K));
Output
22
Time Complexity:- O(N) Where N is the number of nodes
Auxiliary Space:- O(N)