Amazon Interview Question
Country: India
The idea is good. However, nobody has described how to do the next in-order traversal and the reverse in-order traversal in O(1) space. Even the approach mentioned on GeeksforGeeks is NOT O(1).
/**
 * Prints every pair of node values in the BST rooted at {@code root} whose sum equals
 * {@code target}, formatted as "[smaller,larger]".
 * <p>
 * {@code dfs} first rewires the tree in place into a sorted doubly linked list
 * ({@code left} = previous, {@code right} = next), so the original tree shape is lost.
 * The list is then scanned with two pointers moving inward from {@code head} and {@code tail}.
 * Relies on the fields {@code head}, {@code tail} and {@code pre}, which are declared elsewhere.
 *
 * @param root   root of the BST to scan (may be null)
 * @param target the sum each reported pair must add up to
 */
public void findPairs(TreeNode root, int target)
{
    // Reset the list-building state. Otherwise a second call would splice the new list onto the
    // previous call's last node (stale pre) or, for an empty tree, rescan the old list.
    head = null;
    tail = null;
    pre = null;
    dfs(root);
    // Stop when the pointers meet (head == tail) or have just crossed (tail.right == head).
    // For an empty tree head == tail == null, so tail.right is never evaluated.
    while (head != tail && tail.right != head)
    {
        if (head.val + tail.val == target)
        {
            System.out.println("[" + head.val + "," + tail.val + "]");
            head = head.right;
            tail = tail.left;
        }
        else if (head.val + tail.val > target)
        {
            tail = tail.left;
        }
        else
        {
            head = head.right;
        }
    }
}
/**
 * In-order walk that threads the subtree rooted at {@code node} onto the doubly linked list
 * being built in the fields {@code head}/{@code pre}/{@code tail}. Each visited node is linked
 * after {@code pre} ({@code left} = previous, {@code right} = next). The first node visited
 * becomes {@code head}; {@code tail} always holds the most recently visited node.
 *
 * @param node root of the subtree to thread (null is a no-op)
 */
private void dfs(TreeNode node)
{
    if (node != null)
    {
        dfs(node.left);
        if (pre == null)
        {
            // Very first (smallest) node: start of the list.
            head = node;
        }
        else
        {
            // node.left has already been fully visited. pre's original right child, if any,
            // was already handed to its own dfs call, so overwriting pre.right is safe.
            node.left = pre;
            pre.right = node;
        }
        tail = node;
        pre = node;
        // node.right is still the original right child at this point.
        dfs(node.right);
    }
}
Convert the tree into a sorted doubly linked list, then follow the same two-pointer procedure as for a sorted array.
Can you explain how a sorted doubly linked list will use constant memory? If there are n nodes won't the usage be O(n)?
An implementation based on Nascent's comment:
import java.util.*;
public class Solution {
    /**
     * Node of an unbalanced binary search tree of distinct ints. The conversion helpers below
     * reuse {@code leftNode}/{@code rightNode} as the previous/next links of a doubly linked list.
     */
    class TreeNode {
        private TreeNode leftNode;
        private TreeNode rightNode;
        private int value;

        public TreeNode(int value) {
            this.value = value;
        }

        public TreeNode getLeftNode() {
            return leftNode;
        }

        public void setLeftNode(TreeNode leftNode) {
            this.leftNode = leftNode;
        }

        public TreeNode getRightNode() {
            return rightNode;
        }

        public void setRightNode(TreeNode rightNode) {
            this.rightNode = rightNode;
        }

        public int getValue() {
            return value;
        }

        public void setValue(int value) {
            this.value = value;
        }

        /**
         * Inserts {@code val} into the BST rooted at this node. A duplicate value is reported
         * on stdout and not inserted.
         *
         * @param val value to insert
         */
        public void addNode(int val) {
            if (value == val) {
                System.out.println("Error duplicate element: " + value + " = " + val);
                return;
            }
            if (value > val) {
                if (leftNode != null) {
                    leftNode.addNode(val);
                } else {
                    leftNode = new TreeNode(val);
                }
            } else {
                if (rightNode != null) {
                    rightNode.addNode(val);
                } else {
                    rightNode = new TreeNode(val);
                }
            }
        }
    }

    /**
     * Recursively rewires the subtree rooted at {@code root} in place into a sorted doubly
     * linked list and returns {@code root}, which ends up somewhere inside that list.
     *
     * @param root subtree root (may be null)
     * @return {@code root}, or null for an empty subtree
     */
    private static TreeNode btreeToListUtil(TreeNode root) {
        if (root == null) {
            return null;
        }
        if (root.getLeftNode() != null) {
            // The converted left subtree's largest node becomes root's predecessor.
            TreeNode leftTreeNode = btreeToListUtil(root.getLeftNode());
            while (leftTreeNode.getRightNode() != null) {
                leftTreeNode = leftTreeNode.getRightNode();
            }
            leftTreeNode.setRightNode(root);
            root.setLeftNode(leftTreeNode);
        }
        // root's right link is still the original right child here.
        if (root.getRightNode() != null) {
            // The converted right subtree's smallest node becomes root's successor.
            TreeNode rightTreeNode = btreeToListUtil(root.getRightNode());
            while (rightTreeNode.getLeftNode() != null) {
                rightTreeNode = rightTreeNode.getLeftNode();
            }
            rightTreeNode.setLeftNode(root);
            root.setRightNode(rightTreeNode);
        }
        return root;
    }

    /**
     * Converts the tree in place into a sorted doubly linked list. The tree structure is
     * destroyed.
     *
     * @param root tree root (may be null)
     * @return the smallest node (list head), or null for an empty tree
     */
    public static TreeNode btreeToList(TreeNode root) {
        TreeNode head = btreeToListUtil(root);
        if (head == null) {
            return null;
        }
        while (head.getLeftNode() != null) {
            head = head.getLeftNode();
        }
        return head;
    }

    /**
     * Returns the tree's values in ascending order, converting (and destroying) the tree.
     *
     * @param root tree root (may be null)
     * @return ascending list of values; empty for an empty tree
     */
    public static List<Integer> btreeToValueList(TreeNode root) {
        List<Integer> returnList = new ArrayList<>();
        for (TreeNode curr = btreeToList(root); curr != null; curr = curr.getRightNode()) {
            returnList.add(curr.getValue());
        }
        return returnList;
    }

    /**
     * Prints every pair from the ascending list {@code vals} whose sum is {@code k}, as
     * "a + b = k", using two pointers moving inward from both ends.
     *
     * @param vals values sorted ascending
     * @param k    target sum
     */
    public static void printPairsSumK(List<Integer> vals, int k) {
        int start = 0;
        int end = vals.size() - 1;
        while (start < end) {
            // Sum in long: an int sum of two large values would overflow and steer the
            // pointers the wrong way, silently missing valid pairs.
            long sum = vals.get(start).longValue() + vals.get(end).longValue();
            if (sum == k) {
                System.out.println(vals.get(start) + " + " + vals.get(end) + " = " + k);
                start++;
            } else if (sum < k) {
                start++;
            } else {
                end--;
            }
        }
    }

    /**
     * Prints all pairs in the tree summing to {@code sum}. Destroys the tree.
     */
    public void runSolution(TreeNode root, int sum) {
        // 1. Convert Binary Tree to list
        List<Integer> valueList = btreeToValueList(root);
        // 2. print pairs which sum up to K
        printPairsSumK(valueList, sum);
    }

    public void runTests() {
        TreeNode root = new TreeNode(1);
        root.addNode(2);
        root.addNode(3);
        root.addNode(4);
        root.addNode(5);
        root.addNode(6);
        root.addNode(7);
        root.addNode(8);
        int sum = 6;
        // should print 1 + 5 = 6, 2 + 4 = 6
        runSolution(root, sum);
        TreeNode root2 = new TreeNode(2);
        root2.addNode(34);
        root2.addNode(56);
        root2.addNode(23);
        root2.addNode(12);
        root2.addNode(6);
        root2.addNode(78);
        root2.addNode(33);
        root2.addNode(11);
        root2.addNode(20);
        root2.addNode(24);
        root2.addNode(14);
        root2.addNode(30);
        int sum2 = 44;
        // should print 11 + 33 = 44, 14 + 30 = 44, 20 + 24 = 44
        runSolution(root2, sum2);
        TreeNode root3 = new TreeNode(3);
        root3.addNode(-1230);
        root3.addNode(2230);
        root3.addNode(-500);
        root3.addNode(1500);
        root3.addNode(6500);
        root3.addNode(8000);
        root3.addNode(8100);
        root3.addNode(500);
        root3.addNode(2510);
        root3.addNode(1510);
        root3.addNode(490);
        root3.addNode(510);
        root3.addNode(8750);
        root3.addNode(-6750);
        int sum3 = 1000;
        // should print -1230 + 2230 = 1000, -500 + 1500 = 1000, 490 + 510 = 1000
        runSolution(root3, sum3);
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        sol.runTests();
    }
}
Do a tree traversal and add every node value to a hash set. Then, in a second traversal, look up (k - node value) for each node. If that value is in the hash set, you have a pair; otherwise you don't. Note that this needs O(n) extra space for the hash set.
import java.util.*;
import java.lang.*;
import java.io.*;
class Main
{
static class Node
{
Node(Node left, Node right, Node parent, int key)
{
this.left = left;
this.right = right;
this.parent = parent;
this.key = key;
}
public Node left;
public Node right;
public Node parent;
public int key;
}
public static void inOrder (Node node)
{
if(node.left != null) inOrder(node.left);
System.out.println(node.key);
if(node.right != null) inOrder(node.right);
}
public static Node successor (Node node)
{
Node successor = null;
if(node.right != null)
{
node = node.right;
while(node.left != null)
{
node = node.left;
}
successor = node;
}
else if(node.parent != null && node.parent.left == node) successor = node.parent;
else if(node.parent != null && node.parent.right == node)
{
while(node.parent != null && node.parent.right == node)
{
node = node.parent;
}
successor = node.parent;
}
return successor;
}
public static Node predeccessor(Node node)
{
Node predecessor = null;
if(node.left != null)
{
node = node.left;
while(node.right != null)
{
node = node.right;
}
predecessor = node;
}
else if(node.parent != null && node.parent.right == node) predecessor = node.parent;
else if(node.parent != null && node.parent.left == node)
{
while(node.parent != null && node.parent.left == node)
{
node = node.parent;
}
predecessor = node.parent;
}
return predecessor;
}
public static Node min (Node node)
{
while(node.left != null) node = node.left;
return node;
}
public static Node max (Node node)
{
while(node.right != null) node = node.right;
return node;
}
public static void main (String[] args) throws java.lang.Exception
{
Node root = new Node(null, null, null, 5);// 5
Node a = new Node(null, null, null, 2); // / \
Node b = new Node(null, null, null, 1); // 2 8
Node c = new Node(null, null, null, 3); // /\ \
Node d = new Node(null, null, null, 8); // 1 3 9
Node e = new Node(null, null, null, 9); //
root.left = a;
root.right = d;
a.parent = root;
d.parent = root;
a.left = b;
a.right = c;
b.parent = a;
c.parent = a;
d.right = e;
e.parent = d;
int sum = 10;
//find pairs in Binary Search Tree in which sum of nodes keys is equal 10
Node minnode = min(root);
Node maxnode = max(root);
while(minnode.key <= maxnode.key && minnode != maxnode)
{
if (minnode.key + maxnode.key == sum)
{
System.out.println("Pair found: " + minnode.key + " " + maxnode.key);
minnode = successor(minnode);
maxnode = predeccessor(maxnode);
}
else if(minnode.key + maxnode.key < sum)
{
minnode = successor(minnode);
}
else
{
maxnode = predeccessor(maxnode);
}
}
}
}
import java.util.*;
import java.lang.*;
import java.io.*;
class Main
{
static class Node
{
Node(Node left, Node right, Node parent, int key)
{
this.left = left;
this.right = right;
this.parent = parent;
this.key = key;
}
public Node left;
public Node right;
public Node parent;
public int key;
}
public static void inOrder (Node node)
{
if(node.left != null) inOrder(node.left);
System.out.println(node.key);
if(node.right != null) inOrder(node.right);
}
public static Node successor (Node node)
{
Node successor = null;
if(node.right != null)
{
node = node.right;
while(node.left != null)
{
node = node.left;
}
successor = node;
}
else if(node.parent != null && node.parent.left == node) successor = node.parent;
else if(node.parent != null && node.parent.right == node)
{
while(node.parent != null && node.parent.right == node)
{
node = node.parent;
}
successor = node.parent;
}
return successor;
}
public static Node predeccessor(Node node)
{
Node predecessor = null;
if(node.left != null)
{
node = node.left;
while(node.right != null)
{
node = node.right;
}
predecessor = node;
}
else if(node.parent != null && node.parent.right == node) predecessor = node.parent;
else if(node.parent != null && node.parent.left == node)
{
while(node.parent != null && node.parent.left == node)
{
node = node.parent;
}
predecessor = node.parent;
}
return predecessor;
}
public static Node min (Node node)
{
while(node.left != null) node = node.left;
return node;
}
public static Node max (Node node)
{
while(node.right != null) node = node.right;
return node;
}
public static void main (String[] args) throws java.lang.Exception
{
Node root = new Node(null, null, null, 5);// 5
Node a = new Node(null, null, null, 2); // / \
Node b = new Node(null, null, null, 1); // 2 8
Node c = new Node(null, null, null, 3); // /\ \
Node d = new Node(null, null, null, 8); // 1 3 9
Node e = new Node(null, null, null, 9); //
root.left = a;
root.right = d;
a.parent = root;
d.parent = root;
a.left = b;
a.right = c;
b.parent = a;
c.parent = a;
d.right = e;
e.parent = d;
int sum = 10;
//find pairs in Binary Search Tree in which sum of nodes keys is equal 10
Node minnode = min(root);
Node maxnode = max(root);
while(minnode.key <= maxnode.key && minnode != maxnode)
{
if (minnode.key + maxnode.key == sum)
{
System.out.println("Pair found: " + minnode.key + " " + maxnode.key);
minnode = successor(minnode);
maxnode = predeccessor(maxnode);
}
else if(minnode.key + maxnode.key < sum)
{
minnode = successor(minnode);
}
else
{
maxnode = predeccessor(maxnode);
}
}
}
}
import java.util.*;
import java.lang.*;
import java.io.*;
class Main
{
static class Node
{
Node(Node left, Node right, Node parent, int key)
{
this.left = left;
this.right = right;
this.parent = parent;
this.key = key;
}
public Node left;
public Node right;
public Node parent;
public int key;
}
public static void inOrder (Node node)
{
if(node.left != null) inOrder(node.left);
System.out.println(node.key);
if(node.right != null) inOrder(node.right);
}
public static Node successor (Node node)
{
Node successor = null;
if(node.right != null)
{
node = node.right;
while(node.left != null)
{
node = node.left;
}
successor = node;
}
else if(node.parent != null && node.parent.left == node) successor = node.parent;
else if(node.parent != null && node.parent.right == node)
{
while(node.parent != null && node.parent.right == node)
{
node = node.parent;
}
successor = node.parent;
}
return successor;
}
public static Node predeccessor(Node node)
{
Node predecessor = null;
if(node.left != null)
{
node = node.left;
while(node.right != null)
{
node = node.right;
}
predecessor = node;
}
else if(node.parent != null && node.parent.right == node) predecessor = node.parent;
else if(node.parent != null && node.parent.left == node)
{
while(node.parent != null && node.parent.left == node)
{
node = node.parent;
}
predecessor = node.parent;
}
return predecessor;
}
public static Node min (Node node)
{
while(node.left != null) node = node.left;
return node;
}
public static Node max (Node node)
{
while(node.right != null) node = node.right;
return node;
}
public static void main (String[] args) throws java.lang.Exception
{
Node root = new Node(null, null, null, 5);// 5
Node a = new Node(null, null, null, 2); // / \
Node b = new Node(null, null, null, 1); // 2 8
Node c = new Node(null, null, null, 3); // /\ \
Node d = new Node(null, null, null, 8); // 1 3 9
Node e = new Node(null, null, null, 9); //
root.left = a;
root.right = d;
a.parent = root;
d.parent = root;
a.left = b;
a.right = c;
b.parent = a;
c.parent = a;
d.right = e;
e.parent = d;
int sum = 10;
//find pairs in Binary Search Tree in which sum of nodes keys is equal 10
Node minnode = min(root);
Node maxnode = max(root);
while(minnode.key <= maxnode.key && minnode != maxnode)
{
if (minnode.key + maxnode.key == sum)
{
System.out.println("Pair found: " + minnode.key + " " + maxnode.key);
minnode = successor(minnode);
maxnode = predeccessor(maxnode);
}
else if(minnode.key + maxnode.key < sum)
{
minnode = successor(minnode);
}
else
{
maxnode = predeccessor(maxnode);
}
}
}
}
In-order traversal from the left and reverse in-order traversal from the right, as discussed at the beginning of the thread:
/**
 * Prints every pair of keys in the BST rooted at {@code node} whose sum equals {@code k},
 * as "Pair summing to K is: a b", smaller key first.
 * <p>
 * Two explicit stacks drive an ascending in-order iterator ({@code left}) and a descending
 * reverse in-order iterator ({@code right}). Each stack holds at most one root-to-leaf path,
 * so extra space is O(h) for a tree of height h. The scan stops as soon as both iterators reach
 * the same node. A node is therefore never paired with itself, and no stack is popped while empty.
 *
 * @param node root of the BST (may be null)
 * @param k    target sum
 */
public void findSumPairToK(Node node, int k){
    java.util.Deque<Node> leftStack = new java.util.ArrayDeque<>();
    java.util.Deque<Node> rightStack = new java.util.ArrayDeque<>();
    for (Node curr = node; curr != null; curr = curr.left) {
        leftStack.push(curr);
    }
    for (Node curr = node; curr != null; curr = curr.right) {
        rightStack.push(curr);
    }
    if (leftStack.isEmpty()) {
        return; // empty tree: nothing to pair
    }
    // Smallest node. Its right subtree's leftmost path must be queued too, or those nodes are skipped.
    Node left = leftStack.pop();
    for (Node curr = left.right; curr != null; curr = curr.left) {
        leftStack.push(curr);
    }
    // Largest node. Likewise queue its left subtree's rightmost path.
    Node right = rightStack.pop();
    for (Node curr = right.left; curr != null; curr = curr.right) {
        rightStack.push(curr);
    }
    // Invariant: left precedes right in in-order, so every pop below has a node to return.
    while (left != right) {
        // NOTE(review): if data is an int, this sum can overflow for extreme keys.
        boolean tooSmall = left.data + right.data < k;
        boolean tooLarge = left.data + right.data > k;
        if (!tooSmall && !tooLarge) {
            System.out.println("Pair summing to K is: " + left.data + " " + right.data);
        }
        if (!tooLarge) {
            // Advance left to its in-order successor.
            left = leftStack.pop();
            for (Node curr = left.right; curr != null; curr = curr.left) {
                leftStack.push(curr);
            }
        }
        if (!tooSmall && left != right) {
            // Advance right to its in-order predecessor (skipped if left just met right).
            right = rightStack.pop();
            for (Node curr = right.left; curr != null; curr = curr.right) {
                rightStack.push(curr);
            }
        }
    }
}
In-order traversal from the left and reverse in-order traversal from the right, as discussed at the beginning of the thread:
/**
 * Prints every pair of keys in the BST rooted at {@code node} whose sum equals {@code k},
 * as "Pair summing to K is: a b", smaller key first.
 * <p>
 * Two explicit stacks drive an ascending in-order iterator ({@code left}) and a descending
 * reverse in-order iterator ({@code right}). Each stack holds at most one root-to-leaf path,
 * so extra space is O(h) for a tree of height h. The scan stops as soon as both iterators reach
 * the same node. A node is therefore never paired with itself, and no stack is popped while empty.
 *
 * @param node root of the BST (may be null)
 * @param k    target sum
 */
public void findSumPairToK(Node node, int k){
    java.util.Deque<Node> leftStack = new java.util.ArrayDeque<>();
    java.util.Deque<Node> rightStack = new java.util.ArrayDeque<>();
    for (Node curr = node; curr != null; curr = curr.left) {
        leftStack.push(curr);
    }
    for (Node curr = node; curr != null; curr = curr.right) {
        rightStack.push(curr);
    }
    if (leftStack.isEmpty()) {
        return; // empty tree: nothing to pair
    }
    // Smallest node. Its right subtree's leftmost path must be queued too, or those nodes are skipped.
    Node left = leftStack.pop();
    for (Node curr = left.right; curr != null; curr = curr.left) {
        leftStack.push(curr);
    }
    // Largest node. Likewise queue its left subtree's rightmost path.
    Node right = rightStack.pop();
    for (Node curr = right.left; curr != null; curr = curr.right) {
        rightStack.push(curr);
    }
    // Invariant: left precedes right in in-order, so every pop below has a node to return.
    while (left != right) {
        // NOTE(review): if data is an int, this sum can overflow for extreme keys.
        boolean tooSmall = left.data + right.data < k;
        boolean tooLarge = left.data + right.data > k;
        if (!tooSmall && !tooLarge) {
            System.out.println("Pair summing to K is: " + left.data + " " + right.data);
        }
        if (!tooLarge) {
            // Advance left to its in-order successor.
            left = leftStack.pop();
            for (Node curr = left.right; curr != null; curr = curr.left) {
                leftStack.push(curr);
            }
        }
        if (!tooSmall && left != right) {
            // Advance right to its in-order predecessor (skipped if left just met right).
            right = rightStack.pop();
            for (Node curr = right.left; curr != null; curr = curr.right) {
                rightStack.push(curr);
            }
        }
    }
}
The only way to do this in O(n) time with O(1) space is to implement the "next in-order" step and the "reverse in-order" step; one is the mirror of the other. Because the BST is already sorted, start with the leftmost node and the rightmost node and check whether they sum to the target value. If they do, move the left node to its in-order successor and the right node to its in-order predecessor. If the sum is less than the target, advance the left node; if it is greater, advance the right node. Repeat until the left node's data is greater than the right node's data. Voilà!
- Anonymous February 15, 2014