알고리즘 (PS)/BOJ

[백준] 2213 - 트리의 독립집합 (Python)

에버듀 2024. 11. 5. 11:12
반응형

 

https://www.acmicpc.net/problem/2213

 

트리를 구성하는 정점에 가중치가 있을 때, 인접하지 않은 정점들로만 구성된 정점의 부분집합에 대해 가중치의 합이 최대가 되는 부분 집합을 구하는 문제이다. 그에 더해 이 문제는 이 가중치의 합의 최댓값에 더해, 그 값이 나오도록 하는 부분 집합의 원소까지 구해야 한다.

 

이 문제는 트리 DP 로 유명하다.

트리는 그 형태가 이미 재귀적이기 때문에 DP 로 표현하기가 매우 좋다.

 

이 문제의 DP 테이블을 다음과 같이 정의해보자.

 

DP[i][j] = i 번째 노드를 루트로 하는 트리에 대해, 이 노드를 정점에 포함 (j = 0) 또는 포함하지 않을 때 (j = 1) 의 가중치 합의 최댓값

 

이제 점화식을 세워보면

 

만약 i 번째 노드를 계산에 포함한다면, 이 노드의 직속 자식 노드들은 모두 사용할 수 없다.

따라서 DP[i][0] = sum (DP[k][1]) + v[i] (k는 모든 자식 노드의 정점 번호) 이다.

 

만약 i 번째 노드를 계산에 포함시키지 않는다면, 이 노드의 직속 자식 노드들은 모두 포함하거나 포함하지 않을 수 있다.

각각의 경우에 대해 최댓값을 취하면 되므로

DP[i][1] = sum( max(DP[k][0], DP[k][1]) ) (k는 모든 자식 노드들의 정점 번호) 이다.

 

이 과정에서 DP 값을 중복해서 구하는 경우가 발생하므로 메모이제이션을 통해 최적화를 해주면 된다.

 


 

이 문제의 답을 구할 때는 가중치 합의 최대 뿐만 아니라, 그 최댓값을 구성하는 정점들의 정보도 알아야 한다.

역추적 방법을 고민하기 위해 리프노드에서부터 위로 올라가는 방법을 떠올려보자.

 

리프노드 2개를 자식으로 갖는 노드가 있다고 해보자.

이 노드에 대해 DP[i][0] 을 구하는 경우, 리프노드 2개를 모두 사용하지 않고 자신만을 사용하므로, 자신의 정점 번호만 집합에 추가하면 된다.

 

이 노드에 대해 DP[i][1] 을 구하는 경우, 각 리프노드 2개를 포함할 지 말지 결정해서, 최댓값이 포함하는 경우라면 리프노드 정점 번호를 포함하고, 최댓값이 포함하지 않는 경우라면 리프노드 정점 번호를 포함하지 않도록 해서 그 정보를 부모 노드로 올려보낸 뒤, 부모 노드는 그 정보를 합쳐서 그대로 자신의 정점 집합으로 쓰면 된다. 자기 자신읜 포함하지 않기 때문이다.

 

이 과정을 위로 갈 수록 재귀적으로 구현하면 된다.

 


import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**5)

USE, NOT_USE = 0, 1

def solve(node, used):
    if dp[node][used][0] != 0:
        return dp[node][used]

    if used == USE:
        ans = weights[node]
        combi = [node]
        for conn_node in graph[node]:
            if visit[conn_node]:
                continue
            visit[conn_node] = True
            _ans, _combi = solve(conn_node, NOT_USE)
            ans += _ans
            combi += _combi

    else:
        ans = 0
        combi = []
        for conn_node in graph[node]:
            if visit[conn_node]:
                continue
            visit[conn_node] = True
            _ans1, _combi1 = solve(conn_node, USE)
            visit[conn_node] = True
            _ans2, _combi2 = solve(conn_node, NOT_USE)
            if _ans1 > _ans2:
                ans += _ans1
                combi += _combi1
            else:
                ans += _ans2
                combi += _combi2

    dp[node][used] = (ans, tuple(combi))
    visit[node] = False
    return dp[node][used]



n = int(input())
graph = [[] for _ in range(n+1)]
weights = [0] + list(map(int, input().split()))
visit = [False]*(n+1)
for _ in range(n-1):
    s, e = map(int, input().split())
    graph[s].append(e)
    graph[e].append(s)

dp = [[(0, tuple()), (0, tuple())] for _ in range(n+1)]
visit[1] = True
answer1, combi1 = solve(1, USE)
visit[1] = True
answer2, combi2 = solve(1, NOT_USE)
# 
# print(answer1, combi1)
# print(answer2, combi2)
if answer1 > answer2:
    print(answer1)
    print(*sorted(combi1))
else:
    print(answer2)
    print(*sorted(combi2))

 

코드 구현은 위와 같다.

solve 함수는 정점 번호와, 사용 여부를 받아서 그 상태의 최대 가중치 합의 값과 그 때의 노드 구성을 반환하는 함수이다.

반응형