on-policy distillation
on-policy distillation
- 원문 링크
- 요약
- 로컬 llm 을 커스텀 데이터로 학습시키면 기본 성능이 부작용이 발생한다
- 자가 증류 강화학습을 이용해서 이 문제를 해결할 수 있다
- 서론
- llm 학습은 보통 3 단계로 이루어진다
- 프리트레이닝
- 기본 언어 능력
- 미드 트레이닝
- 특화된 분야 전문 능력
- 포스트 트레이닝
- 질문 답변 대화 같은 형식
- 논리 추론의 특화 (강화학습)
- 프리트레이닝
- llm 학습은 보통 3 단계로 이루어진다
- 강화학습의 기본 개념
- on-policy training
- 강화학습용 데이터(롤아웃)가 모델 자기 자신에게서 생성된 것. 거기에 점수를 부여하는 것만 외부의 도구 활용
- 장점
- 스스로의 행동에 대해 피드백을 받는거라 흡수가 잘됨
- 단점
- 생성해야 하는 데이터 분량 대비 흡수할 수 있는 피드백 정보의 분량 비율이 낮음
- 비유하자면, 자기주도 학습을 하겠답시고 학원이나 강의를 전혀 듣지 않으면서 스스로의 능력으로만 문제를 풀고 오답 노트 공부만 하는 학생의 케이스
- off-policy training
- 데이터가 외부에서 나오고 그걸 따라하게 함
- SFT 는 off-policy training 의 특화된 사례라고 할 수 있음
- 마이너스 보상 없이 플러스 보상만 주어짐
- 장점
- 교사 모델로부터 단순 토큰뿐 아니라 토큰의 확률분포 정보까지 받으면 훨씬 더 좋은 정보 (불편추정량에 가까운) 를 받아 따라할 수 있음
- 각 토큰의 샘플링 결과 대신 logit 확률값을 받아올 수 있는 경우
- 교사 모델로부터 단순 토큰뿐 아니라 토큰의 확률분포 정보까지 받으면 훨씬 더 좋은 정보 (불편추정량에 가까운) 를 받아 따라할 수 있음
- 단점
- 모델이 스스로 생각한 경로가 아니라 외부에서 주입된 경로이기 때문에 스스로의 실수를 교정하는 부분에 대한 학습이 어려움
- 특히 문장이 길어질 수록, 학생 모델이 생성 도중에 실수하는 토큰을 생성하는 부분이 생길 수 있는데, 선생모델은 절대 실수를 하지 않는다면 이부분을 배울 수 없음
- 비유하자면, 일타 강사의 강의들을 열심히 들으면서 문제와 모범 답안을 외우기만 하지, 막상 스스로 시간을 정해서 풀고 오답노트를 작성하는 생략하는 것과 같음
- 모델이 스스로 생각한 경로가 아니라 외부에서 주입된 경로이기 때문에 스스로의 실수를 교정하는 부분에 대한 학습이 어려움
- 위 두가지 학습방법의 장점을 결합하면 어떨까?
- 어떤 문제에 대해서 학생이 문제를 풀어놓으면 과외 선생이 그 풀이를 보고 어느 부분에서 잘못되었는지 단계단위로 짚어주고, 다시 풀게 해보는 방식
- 맞았다 틀렸다라는 'sparse' 한 보상 대신, 세세하게 어디에서부터 어긋났다라는 'dense' 한 보상
- 어떤 문제에 대해서 학생이 문제를 풀어놓으면 과외 선생이 그 풀이를 보고 어느 부분에서 잘못되었는지 단계단위로 짚어주고, 다시 풀게 해보는 방식
- on-policy training
- 구체적인 구현 방법
- 손실 함수
- 역 KL 발산도를 활용
- 선생 모델을 기준으로 학생모델이 얼마나 충실하게 정보를 재현할 수 있나
- 일반 KL
- 학생모델을 기준으로 선생모델이 얼마나 충실하게 정보를 재현하나
- 리버스 KL
- KL 의 특성상 위 2개는 같은 값이 아님
- KL 은 거리가 아니라 발산도임
- 장점
- 리워드 해킹이 쉽게 일어나지 못함
- 선생 모델을 기준으로 학생모델이 얼마나 충실하게 정보를 재현할 수 있나
- 역 KL 발산도를 활용
- 구현
- 학생 모델에 질문에 대한 답을 생성 (롤아웃) 을 시킴. 그 과정에서 최종 샘플링된 토큰 외에도 각 토큰의 확률 분포를 남기게 함
- 생성된 토큰들에 대해 선생모델에게 자신의 경우였다면 어떤 확률 분포를 가질지 계산시킴
- 리버스 KL 값을 계산하고 어드밴티지 값을 구함
- 어드밴티지 : 그냥 생성하는 경우 대비 모범답안이 얼마나 더 좋아졌는지를 수치로 표현
- 손실 함수
- 응용사례: 언어모델 특화
- 내부 문서에 특화한 질문-답변 조수 모델을 만든다고 가정하자
- 예제 모델은 qwen3-8b
- 그냥 학습시키면 망각증상이 일어남
- 보통 사용하는 대안
- 기반 모델을 학습시킬때 사용했던 데이터를 구해서 섞어주면 망각을 완화할 수 있다
- 어중간한 규모의 데이터로는 완화에 불과함. 해결은 안됨 (원래의 능력을 95% 이상 회복하지 못함)
- 그리고 기반 모델을 학습할 때 사용했던 데이터를 구하기도 어렵다
- 데이터를 구하기 어려운 문제는 대안이 있긴 하다
- tulu 같은 데이터셋을 이용해서, 거기서 프롬프트만 따서 튜닝 안된 버전의 모델로 대답을 생성하게 만들기
- 32b 모델처럼 더 강한모델보다 8b 모델 스스로 생성한 대답이 효과가 더 좋음
- 이렇게 해서 모델의 원래 능력으로 붙잡아주는 데이터를 '규제화' 데이터라고 한다.
- 이 경우에는 채팅 능력이 보존 대상이니 규제화 데이터를 채팅 형식으로 만듬
- 실험 결과
- 새로 학습시킬 데이터 (원래 모델이 모르는 데이터) 와 규제화 데이터의 혼합비율에 따라 평가 점수가 달라짐
- 규제화 데이터를 최소 30% 는 넣어줘야 능력이 망가지지 않는 것을 확인할 수 있다
- 하지만 일반 능력의 정확도도 85% 에서 80% 정도로 떨어짐
- 그리고 규제화 데이터를 많이 넣을 수록, 새로운 지식의 습득 정도는 당연히 떨어진다
- 새로운 지식을 많이 습득시키면서도 떨어진 일반 수행 능력을 거의 완벽하게 복원할 방법이 바로 오늘 소개한 자가 증류 강화학습
- 기반 모델을 학습시킬때 사용했던 데이터를 구해서 섞어주면 망각을 완화할 수 있다
- 자가 증류 강화학습
- 모델이 새 지식을 습득하면서 망각한 부분만 최소한의 수정으로 복원시킬 방법
- 구현
- 1단계 - 일단 7:3 으로 혼합시켜 훈련 시킨 새버전의 모델을 준비한다. 이것이 학생모델이 된다
- 이 학생 모델은 새로 학습한 내부 데이터에 대한 정답률은 18%에서 36%로 2배 증가 했지만 일반 문답 능력은 85% 에서 79% 로 무시할 수 없을 정도의 하락을 보임
- 2단계 - 증류를 하되, 별도의 교사 모델을 따로 두는 대신, 파인튜닝을 하기 이전 모델을 교사로 활용한다
- 새로운 지식이 아닌 일반 지식에 대해서 학생 모델에 토큰 분포를 생성(롤아웃)하게 하고, 교사모델 (원래 버전) 의 토큰 분포와 리버스 kl 을 계산해서 업데이트 시킨다
- 이렇게 했더니 신규 지식의 정답률은 41%로 오히려 더 올라가고, 기존의 문답 능력은 83% 로 2% 밖에 차이 안나는 수준까지 복구된다!
- 1단계 - 일단 7:3 으로 혼합시켜 훈련 시킨 새버전의 모델을 준비한다. 이것이 학생모델이 된다
- 추가적으로 시도해볼 사항
- 1단계와 2단계를 한 다음에 거기서 끝내지 말고 다시 1 단계를 또 하고 2 단계를 또 하고 하는 식으로 반복을 돌리면 더 추가적인 향상을 기대할 수 있다고 한다
- 내부 문서에 특화한 질문-답변 조수 모델을 만든다고 가정하자
- 그외에 생각해볼 포인트
- 데이터셋의 재활용
- 데이타셋이 부족한 상태에서 같은 데이타셋에 에포크를 반복해서 돌리면, 단순 암기현상이 발생한다
- 단순 반복학습 대신 자가 증류 강화학습을 이용하면 단점을 줄이고 꾸준히 학습시킬 수 있다
- 극단적인 실험 예제
- 딱 한가지 수학문제 풀이 데이터만 이용해서 20번을 반복학습 시킨 경우. 각 1번의 에포크마다 256개의 롤아웃을 만들어서 학습시킴.
- 데이타셋이 부족한 상태에서 같은 데이타셋에 에포크를 반복해서 돌리면, 단순 암기현상이 발생한다
- 일반 학습과의 비교
- 일반 학습은 파라메터를 업데이트하는데에 많은 시간을 소모한다.
- 강화학습은 어떤 파라메터를 업데이트할지 찾아내는데에 대부분의 시간을 소모하고, 그 파라메터를 업데이트하는 데에는 상대적으로 적은 시간을 소모한다
- 그 결과로 생각의 경로에 지름길들이 형성되고 이후에는 그 지름길을 활용하게 된다
- 지속 학습을 위한 on-policy 학습
- 온폴리시 RL 은 망각이 덜한대신 (파라메터 업데이트가 적음) 생각을 다듬는데에 주로 효과가 국한된다.
- 따라서 새로운 정보를 주입시키는 데에는 한계가 있다
- 반면 SFT 나 off-policy RL 로는 망각이 커진다.
- 조금 더 실험해보자
- 강한 모델 (32B) 를 이용해서 tulu3 의 프롬프트에 대한 답변을 만들어서 데이터셋을 생성한다
- 이것은 32b 에 대해서는 KL 발산도가 0 이다. 자기 자신이니까.
- 그럼에도 불구하고 그 데이터셋에 추가 학습 (SFT = 오프폴리시) 을 시키면 kl 발산도가 증가한다. 즉 약간 성능이 하락한다
- 물론 계속 하면 할수록 성능이 올라오긴 하지만 그 속도는 더디다.
- 애초에 KL 발산도가 0 이라는 것은 기대값으로서의 의미이다. 무한히 추가학습을 골고루 시키다보면 원래 위치로 돌아올 수 있다는 의미
- 그에 비해 온폴리시 RL 은 망각 방지에 강한 모습을 보인다
- 꼭 필요한 파라메터만 교정하는 효과 덕분
- 지속 학습에는 어떤 형식으로든 온폴리시 강화학습이 효과적인 방식으로 자리 잡을 듯
- 온폴리시 RL 은 망각이 덜한대신 (파라메터 업데이트가 적음) 생각을 다듬는데에 주로 효과가 국한된다.
- 데이터셋의 재활용