#encoding=utf-8
import sys
import math
def kth_power_sum(n, k):
    """Return the value of k's binary digits read as base-n digits.

    Equivalently, this is the sum of ``n ** i`` over every bit position
    ``i`` that is set in ``k``. For ``n >= 2`` that is the k-th term
    (1-indexed) of the increasing sequence of numbers expressible as a
    sum of distinct non-negative powers of ``n``; e.g. for ``n = 3`` the
    sequence starts 1, 3, 4, 9, 10, 12, 13, ...

    Args:
        n: Base whose powers are summed.
        k: 1-based position in the sequence; must be a positive integer.

    Returns:
        The k-th term as an exact integer.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    # Pure integer bit arithmetic: no floating-point logarithm is needed
    # to find the highest set bit (float logs can misround near powers
    # of two), and no table of all preceding terms has to be built.
    total = 0
    power = 1  # n ** i for the bit position i currently being examined
    while k:
        if k & 1:
            total += power
        power *= n
        k >>= 1
    return total


def main():
    """Read ``n k`` from the first line of stdin and print the k-th term."""
    # split() without an argument tolerates repeated spaces, tabs and a
    # trailing newline or carriage return.
    n, k = (int(token) for token in sys.stdin.readline().split())
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(kth_power_sum(n, k))


if __name__ == "__main__":
    main()