알고리즘 (PS)/BOJ

[백준] 2533 - 사회망 서비스(SNS) (Python)

에버듀 2024. 11. 5. 16:25
반응형

https://www.acmicpc.net/problem/2533

 

트리 DP 문제

DP 테이블을 다음과 같이 정의한다.

 

dp[i][j] = i 번째 노드가 루트인 트리에 대해 이 노드가 얼리어답터 인지 (j = 0) 아닌지 (j = 1) 에 따른 그 트리에서의 얼리어답터 최소 수

 

점화식은 다음과 같이 쓸 수 있다.

 

dp[i][0] = i 번째 노드가 루트인 트리에 대해 이 노드가 얼리어답터인 경우, 모든 자식노드는 얼리어답터여도 되고, 아니어도 된다. 두 경우를 모두 구해서 최소 값을 취한다. 따라서 sum( min( dp[k][0], dp[k][1] ) ) + 1, ( k 는 직접 연결된 자식 노드들의 번호 )

 

dp[i][1] = i 번째 노드가 루트인 트리에 대해, 이 노드가 얼리어답터가 아닌 경우, 모든 자식 노드는 얼리어답터여야 한다.

(자신과 연결된 모든 친구가 얼리어답터여야 받아들이기 때문이다. 부모 노드는 고려하지 않아도 괜찮다. dp[i][1] 을 요구하는 경우는 부모노드가 얼리어답터인 경우에만 요구하기 때문이다.)

따라서 sum( dp[k][0] ) (이때 k는 직접 연결된 자식 노드들의 번호) 를 구한다.

 

초기값은 자식 노드가 없는 리프노드에 대해 결정하면 된다.

리프노드의 경우 자식이 아무도 없다면 자기 자신이 얼리어답터가 될 수 밖에 없다.

하지만 자신이 얼리어답터가 되지 않는 경우도 일단 카운팅할 수 있도록 DP 테이블에는 INF 값을 저장하였다.

얼리어답터인 부모가 자식 노드를 고려할 때는 자식 노드 중에 얼리어답터가 아닌 리프 노드를 가져오는 것이 이득이므로 이 경우만 특수하게 처리해주었다. (아래 코드에서 v2 == INF 이면 v2 = 0 으로 처리하는 부분)

 

import sys
from array import array
input = sys.stdin.readline
sys.setrecursionlimit(10**6+100)

INF = 987654321
YES, NO = 0, 1

def solve(node, adapter):
    if dp[node][adapter] < INF:
        return dp[node][adapter]

    if adapter == YES:
        ret = 1
        for connected in graph[node]:
            if visit[connected]:
                continue

            visit[connected] = True
            v1 = solve(connected, YES)
            visit[connected] = True
            v2 = solve(connected, NO)
            if v2 == INF:
                v2 = 0
            ret += min(v1, v2)

    else:
        ret = 0
        for connected in graph[node]:
            if visit[connected]:
                continue

            visit[connected] = True
            v1 = solve(connected, YES)
            ret += v1
        if ret == 0:
            visit[node] = False
            return INF

    dp[node][adapter] = ret
    visit[node] = False
    return ret


n = int(input())
graph = [[] for _ in range(n+1)]
for _ in range(n-1):
    s, e = map(int, input().split())
    graph[s].append(e)
    graph[e].append(s)

dp = [[INF, INF] for _ in range(n+1)]
visit = array('B', [False]*(n+1))

visit[1] = True
ans1 = solve(1, YES)
visit[1] = True
ans2 = solve(1, NO)
# print(dp)
print(min(ans1, ans2))

 

반응형