import java.util.*;
/**
 * Reads two integers {@code m} and {@code k} (in that order) from standard input and prints
 * how many steps can be taken when every step consumes one unit from a stock that starts at
 * {@code m} units, and one unit is put back right after every step whose 1-based index is a
 * multiple of {@code k}.
 */
public class Main {

    /**
     * Program entry point: reads {@code m} and {@code k}, prints the step count on one line.
     *
     * @param args ignored
     * @throws IllegalArgumentException if {@code m > 0} and {@code |k| < 2}, because the
     *     process would then never end (or divide by zero for {@code k == 0})
     */
    public static void main(String[] args) {
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner in = new Scanner(System.in);
        int m = in.nextInt();
        int k = in.nextInt();
        System.out.println(countSteps(m, k));
    }

    /**
     * Computes the number of steps without simulating them one by one.
     *
     * <p>A step is possible while the stock is non-empty; a refill of one unit happens after
     * each step whose index is divisible by {@code k}. The stock is therefore only ever
     * drained by steps whose index is <em>not</em> a multiple of {@code |k|}, and the last
     * step is the {@code m}-th such index, which is {@code m + (m - 1) / (|k| - 1)}.
     *
     * @param m initial stock; a non-positive value allows no step at all
     * @param k refill period; only its magnitude matters, since a remainder is zero
     *     regardless of the divisor's sign
     * @return the total number of steps, as a {@code long} so that large inputs do not wrap
     * @throws IllegalArgumentException if {@code m > 0} and {@code |k| < 2}
     */
    private static long countSteps(int m, int k) {
        if (m <= 0) {
            // Nothing to consume, so the period is never even consulted.
            return 0L;
        }
        // Widen before abs() so that Integer.MIN_VALUE is handled correctly.
        long period = Math.abs((long) k);
        if (period < 2) {
            // |k| == 1 refills after every step (endless); k == 0 has no defined remainder.
            throw new IllegalArgumentException(
                    "k must satisfy |k| >= 2 when m > 0, but was " + k);
        }
        return m + (m - 1) / (period - 1);
    }
}