Maximum Product of Splitted Binary Tree


Problem Statement
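Given the root of a binary tree, split the binary tree into two subtrees by removing one edge such that the product of the sums of the subtrees is maximized.
Return the maximum product of the sums of the two subtrees. Since the answer may be too large, return it modulo 10^9 + 7.
Note that you need to maximize the answer before taking the mod, not after taking it.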
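The order of the two steps matters. The small sketch below uses made-up candidate products that do not come from any real tree. If you reduce each candidate modulo 10^9 + 7 before comparing, a larger product can look smaller, and the wrong split gets selected:

```python
MOD = 10**9 + 7

# Hypothetical candidate products for two different splits (illustrative only).
products = [10**9 + 8, 5]

# Correct: choose the largest exact product first, then reduce it.
correct = max(products) % MOD           # (10**9 + 8) % MOD == 1

# Wrong: reducing first turns 10**9 + 8 into 1, so the smaller product 5 wins.
wrong = max(p % MOD for p in products)  # == 5

print(correct, wrong)
```

Python integers have arbitrary precision, so it is safe to compute the exact products and reduce only the final maximum.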
Link to the question on Leetcode
Example 1:

Input: root = [1,2,3,4,5,6]
Output: 110
Explanation: Remove the red edge to get two binary trees with sums 11 and 10. Their product is 110 (11*10).



Approach

The problem asks us to split the binary tree into two subtrees so that the product of their sums is maximized. Splitting the binary tree means removing exactly one edge from it.

We can solve this problem by precomputing the subtree sum of every node. With those values, the product for any split can be computed in constant time.

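The key observation is about what happens when you remove the edge above a node. That node's subtree ends up on one side and the rest of the tree on the other. The product for that split is therefore (subtree sum) * (total tree sum - subtree sum). The answer is the maximum of this value over all nodes.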
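Here is a minimal sketch of that formula, applied to the subtree sums of Example 1. `best_split` is a hypothetical helper name. The sums were worked out by hand for the tree [1,2,3,4,5,6]: node 1 gives 21, node 2 gives 11 (2 + 4 + 5), node 3 gives 9 (3 + 6), and each leaf gives its own value.

```python
def best_split(subtree_sums, total):
    """Return the largest s * (total - s) over all subtree sums s."""
    return max(s * (total - s) for s in subtree_sums)


# Subtree sums for Example 1, nodes 1, 2, 3, 4, 5, 6 in level order.
example1_sums = [21, 11, 9, 4, 5, 6]
print(best_split(example1_sums, total=21))  # 110, from 11 * (21 - 11)
```

The root's own entry contributes 21 * 0 = 0, so leaving it in the list never changes the maximum.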

Let's work through the example [2,3,9,10,7,8,6,5,4,11,1]. Below is the tree representation of this input.
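The input list uses LeetCode's level-order format. As a sketch, the hypothetical helper `build_tree` below builds the corresponding tree, treating a `None` entry as a missing child. It then prints a few nodes to confirm the shape: "2" is the root, "3" and "9" are its children, "10" and "7" are the children of "3", and so on. `TreeNode` mirrors LeetCode's node definition.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is this node's left child, the one after is its right child.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([2, 3, 9, 10, 7, 8, 6, 5, 4, 11, 1])
print(root.val, root.left.val, root.right.val)            # 2 3 9
print(root.left.left.val, root.left.right.val)            # 10 7
print(root.left.left.left.val, root.left.left.right.val)  # 5 4
```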



First, calculate the subtree sum for each node. The subtree sums are shown below:
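These sums can also be recomputed directly from the level-order list. The sketch below assumes the list has no `None` gaps, which holds for this example. Under that assumption, the children of index `i` sit at `2*i + 1` and `2*i + 2`. Walking the indices from last to first guarantees that each child is summed before its parent, which is the array form of a post-order traversal.

```python
values = [2, 3, 9, 10, 7, 8, 6, 5, 4, 11, 1]

# sums[i] ends up as the sum of the subtree rooted at values[i].
sums = values[:]
for i in range(len(values) - 1, -1, -1):
    for child in (2 * i + 1, 2 * i + 2):
        if child < len(values):
            # Children have larger indices, so their sums are already final.
            sums[i] += sums[child]

for v, s in zip(values, sums):
    print(f"node {v}: subtree sum {s}")
```

For this tree, the sketch reports the following sums:

- 66 for the root "2", which is the total tree sum
- 41 for "3"
- 23 for "9"
- 19 for both "10" and "7"
- the node's own value for each leaf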



Now let's try removing the edge between the root node "2" and its left child "3". This produces two subtrees, and both of their sums can be obtained directly from the precomputed values, as shown in the image below.



The sum of subtree 1 comes directly from the precomputed values. The sum of the other subtree is the total tree sum minus the sum of subtree 1. Multiplying the two gives the product for this split.
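Plugging in this example's numbers gives a concrete case. The subtree sums come from the post-order step above: 66 for the whole tree and 41 for the subtree rooted at "3". The product for cutting the edge between "2" and "3" works out as follows:

```python
total_sum = 66     # subtree sum of the root "2", i.e. the whole tree
subtree1_sum = 41  # subtree sum of node "3"

# The other side is node "2" plus the subtree rooted at "9": 2 + 23 = 25.
subtree2_sum = total_sum - subtree1_sum
product = subtree1_sum * subtree2_sum
print(subtree2_sum, product)  # 25 1025
```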

In the same way, try removing every edge one at a time, and keep the maximum product of the two subtree sums.
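Every non-root node has exactly one edge to its parent, so "try every edge" amounts to "try every non-root node". The sketch below applies this to the example's subtree sums, keyed by node value; that keying works only because all values in this tree are distinct. It reports which edge to cut and the resulting product.

```python
# Subtree sums of the example tree, keyed by node value (values are distinct here).
subtree_sums = {2: 66, 3: 41, 9: 23, 10: 19, 7: 19, 8: 8,
                6: 6, 5: 5, 4: 4, 11: 11, 1: 1}
root_value = 2
total = subtree_sums[root_value]

# Cutting the edge above `node` separates its subtree (sum s) from the rest.
best_node, best_product = max(
    ((node, s * (total - s)) for node, s in subtree_sums.items() if node != root_value),
    key=lambda pair: pair[1],
)
print(best_node, best_product)  # 3 1025
```

In this example, the best split is the very edge examined above: the one between "2" and "3", which yields 41 * 25 = 1025.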

Note: To precompute all subtree sums, use a post-order traversal of the binary tree (children before parent) and store each node's sum in a dictionary keyed by the node.
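The `subtree_sum` helper in the code below performs this traversal recursively. On a very deep, skewed tree, that recursion can exceed Python's default recursion limit of 1000 frames. As a hedged alternative, here is a sketch of the same post-order traversal using an explicit stack. `subtree_sums_iterative` is a hypothetical name, and `TreeNode` mirrors LeetCode's node definition.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sums_iterative(root: Optional[TreeNode]) -> Dict[TreeNode, int]:
    """Post-order traversal with an explicit stack; returns {node: subtree sum}."""
    sums: Dict[TreeNode, int] = {}
    stack = [(root, False)] if root else []
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Both children (if present) were processed earlier, so their sums exist.
            left = sums[node.left] if node.left else 0
            right = sums[node.right] if node.right else 0
            sums[node] = node.val + left + right
        else:
            # Revisit this node only after its children have been handled.
            stack.append((node, True))
            if node.right:
                stack.append((node.right, False))
            if node.left:
                stack.append((node.left, False))
    return sums


# Example 1 tree: [1,2,3,4,5,6]
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
sums = subtree_sums_iterative(root)
print(sums[root], sums[root.left], sums[root.right])  # 21 11 9
```

Nodes are used as dictionary keys through Python's default identity-based hashing, which matches the dictionary approach described in the note.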


Code
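from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maxProduct(self, root: Optional[TreeNode]) -> int:
        if root is None:
            return 0
        # Map each node to the sum of the subtree rooted at it.
        subtree_sum = dict()
        self.subtree_sum(root, subtree_sum)
        total_sum = subtree_sum[root]
        maximum = 0
        # Cutting the edge above a node leaves its subtree (sum s) and the rest
        # (total_sum - s). The root has no parent edge, but its product is
        # total_sum * 0 = 0, so including it does not affect the maximum.
        for node in subtree_sum:
            maximum = max(maximum, subtree_sum[node] * (total_sum - subtree_sum[node]))
        # Maximize first, then take the modulus. Use exact integers:
        # math.pow returns a float, which can lose precision for large products.
        return maximum % (10**9 + 7)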

    def subtree_sum(self, root, subtree_sum):
        # Post-order traversal: visit both children before the node itself,
        # so their sums are already stored when this node's sum is computed.
        if root is None:
            return
        self.subtree_sum(root.left, subtree_sum)
        self.subtree_sum(root.right, subtree_sum)
        left_sum = subtree_sum[root.left] if root.left else 0
        right_sum = subtree_sum[root.right] if root.right else 0
        subtree_sum[root] = left_sum + right_sum + root.val