import numpy as np


def policy_iteration(env, policy, discount_factor=1.0):
    """Alternate policy evaluation and greedy policy improvement until the
    policy no longer changes, then return the stable policy and its values."""
    while True:
        # Evaluate the current policy to get its state-value function V.
        V = policy_evaluation(policy, env, discount_factor)
        policy_stable = True
        for s in range(env.nS):
            # Action the current (deterministic) policy selects in state s.
            old_action = np.argmax(policy[s])
            # One-step lookahead: expected return of each action under V.
            action_values = np.zeros(env.nA)
            for a in range(env.nA):
                for prob, next_state, reward, done in env.P[s][a]:
                    action_values[a] += prob * (reward + discount_factor * V[next_state])
            best_action = np.argmax(action_values)
            # Greedy improvement: if the best action changed, keep iterating.
            if old_action != best_action:
                policy_stable = False
            policy[s] = np.eye(env.nA)[best_action]
        if policy_stable:
            return policy, V
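

# Minimal usage sketch. Assumptions: a GridworldEnv class (or any environment
# exposing nS, nA, and a transition dict P) and the policy_evaluation helper
# referenced above are already defined or imported; names here are illustrative.
if __name__ == "__main__":
    env = GridworldEnv()  # hypothetical environment instance; substitute your own
    # Start from a uniformly random policy: each row is a distribution over actions.
    random_policy = np.ones([env.nS, env.nA]) / env.nA
    optimal_policy, V = policy_iteration(env, random_policy, discount_factor=1.0)
    print("Optimal policy (one-hot action probabilities per state):")
    print(optimal_policy)
    print("Value function under the optimal policy:")
    print(V)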