:question:문제 설명

N개의 행성이 있고 행성끼리 연결을 해야하는데 최소 비용으로 연결을 해야하는 문제

:pencil2:코드

import sys
input = sys.stdin.readline


def find(x):
  if parent[x] != x:
    parent[x] = find(parent[x])
  return parent[x]

def union(x, y):
  x = find(x)
  y = find(y)

  if x < y:
      parent[y] = x
  elif x > y:
      parent[x] = y

N = int(input())
arr = [list(map(int, input().split())) for _ in range(N)]
parent = [i for i in range(N)]
cost = []
for i in range(N):
  for j in range(i + 1, N):
    cost.append((arr[i][j], i, j))

cost.sort()
ans = 0
for c, x, y in cost:
  x = find(x)
  y = find(y)

  if x != y:
      union(x, y)
      ans += c

print(ans)

:memo:풀이

오랜만에 풀어보는 최소 스패닝 트리 문제였다.

최소 스패닝 트리란 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 정점간의 가중치의 합이 최소인 트리를 말한다.

이 문제를 풀이하는 방법 중 하나는 Union-Find 자료구조를 이용하는 것이다.

Union 단계에서는 두 정점을 하나의 부모 아래 묶어주는 단계이다.

예를들어 2번 노드의 부모노드는 1번이고 4번 노드의 부모가 3번이면 2번 노드와 4번 노드를 묶어주기 위해 1번 노드와 3번 노드를 묶어주는 것이다.

Union 단계에서 쓰이는게 Find인데 해당 노드의 부모를 찾아주는 함수이다.

이때 경로압축 방법을 써서 각 노드의 바로 위 노드가 아닌 가장 위에있는 부모 노드를 부모로 설정하여 시간을 단축 할 수 있다.

아래 방법은 경로압축 방법을 쓰지 않은 경우인데 위 코드랑 차이점은 parent[x] 를 바로바로 갱신해주지 않는다는 점이다.

def find(x):
    if parent[x] != x:
        return find(parent[x])
   	return x

문제로 다시 돌아가서, cost라는 배열에 (비용, 노드1, 노드2) 순서의 형태로 담아준 다음에 정렬을 해준다.

정렬을 해주는 이유는 가장 작게 비용을 쓰기 위해서이다.

cost에서 하나씩 꺼면서 노드1과 노드2의 부모가 같지 않다면 두개를 연결해주고 최종 비용에 현재 비용을 더해주면 끝!

댓글남기기