import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn, optim, autograd
from visdom import Visdom
# Generate the real-data dataset
def data_generator():
    """
    Yield an endless stream of mini-batches drawn from a mixture of 8 Gaussians.

    The 8 component centers lie on the unit circle at 45-degree steps and are
    scaled by ``scale``. Each sample is one randomly chosen center plus
    isotropic Gaussian noise (standard deviation 0.02), and the whole batch is
    then divided by 1.414.

    The batch size is read from the module-level ``batch_size`` variable each
    time a new batch is built, so that variable must be defined before the
    first batch is requested.

    Yields:
        numpy.ndarray: array of shape ``(batch_size, 2)`` and dtype float32.
    """
    scale = 2.
    centers = [
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for _ in range(batch_size):
            # Draw the noise first, then the center: this keeps the order of
            # RNG calls stable so seeded runs stay reproducible.
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        # Rescale the batch; the original comment labels this constant "stdev".
        dataset /= 1.414
        yield dataset
# hyper-parameters
hidden_d
# wgan-gp
# (Web-page residue) Latest recommended article published 2023-05-25 15:31:18