Reset():重置整个环境,包括城市坐标点。
reset():重置环境状态(路径、当前位置等),但不重置城市坐标点。
这里的 render() 参考了之前读过的一篇文章(具体是哪一篇已记不清,当时也没有收藏),使用 matplotlib 实现。
import warnings
import numpy as np
import matplotlib.pyplot as plt
class TSPEnvironment:
"""
__init__() parm: num city, coordinate_dimension, box size
step() and reset() return: (coordinates, path, valid) -> state, reward, done
"""
def __init__(self, num_cities, coordinate_dimension=2, box_size=1.0):
    """Store the environment configuration, then run ``Reset()`` once.

    Args:
        num_cities: number of cities in the instance.
        coordinate_dimension: dimensionality of each city coordinate;
            must be >= 2 (checked with ``assert``).
        box_size: stored as ``self.box_size``. Presumably this is the extent
            of the region that ``Reset()`` samples coordinates from
            (TODO confirm against ``Reset``).
    """
    assert coordinate_dimension >= 2, "coordinate_dimension must >= 2 !"

    # Caller-supplied configuration.
    self.num_cities = num_cities
    self.coordinate_dimension = coordinate_dimension
    self.box_size = box_size

    # Instance / episode state. These start empty and are expected to be
    # filled in by Reset() (its body is not visible here -- confirm).
    self.coordinates = None
    self.cities_coordinates = None
    self.path = None
    self.now_location = None
    self.done = False
    self.total_distance = 0.0

    # Keep a private (name-mangled) handle to the full reset routine and
    # invoke it immediately, so the object is ready to use after construction.
    self.__init_environment = self.Reset
    self.__init_environment()
def reset(self, start_city=None):
    """Begin a new episode on the current city layout.

    This method does not modify ``self.coordinates`` or
    ``self.cities_coordinates``. It only clears the tour state: current
    location, path, ``done`` flag and accumulated distance.

    Args:
        start_city: index of the city to start from. If None, a start city
            is drawn uniformly at random from the keys of
            ``self.cities_coordinates``.

    Returns:
        ``((coordinates, path, valid), reward, done)`` where:

        - ``coordinates`` is an ndarray copy of ``self.coordinates``.
        - ``path`` is a copy of the new path, i.e. ``[start]``.
        - ``valid`` is whatever ``self.get_valid_cities`` returns.
        - ``reward`` is always ``0.0``.
        - ``done`` is always ``False``.

    Raises:
        AssertionError: if ``start_city`` is given and is not in
            ``[0, self.num_cities)``.
    """
    if start_city is None:
        # np.random.choice with no weights samples uniformly.
        start_city = np.random.choice(list(self.cities_coordinates.keys()))
    else:
        # Check the lower bound as well. A negative index would pass an
        # upper-bound-only check and could silently wrap around if it is
        # later used as a list/ndarray index.
        assert 0 <= start_city < self.num_cities, \
            "Start city must be >= 0 and < num of city !!!"

    self.now_location = start_city
    self.path = [self.now_location]
    self.done = False
    self.total_distance = 0.0

    valid = self.get_valid_cities(self.path, self.coordinates)

    # Return copies so the caller cannot mutate internal state through the
    # returned observation. np.array() copies by default and, unlike
    # re-stacking rows, keeps the column shape even when there are no rows.
    coordinates = np.array(self.coordinates)
    path = list(self.path)
    return (coordinates, path, valid), 0.0, self.done
def Reset(self, start_city=None):
if start_city is not None: