# Oversampling (过采样)
import numpy as np
import itertools
# Number of distinct classes; used as the one-hot depth when oversampling.
NUM_LABELS = 2

# Example label vector: twelve binary labels (nine 1s and three 0s).
y_origin = np.array([1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1])
def to_one_hot(data, depth):
    """Encode a 1-D sequence of integer labels as a one-hot matrix.

    Args:
        data: 1-D array-like of integer class labels (list, tuple or
            ndarray).
        depth: Number of classes, i.e. the number of columns in the result.

    Returns:
        An ``int32`` ndarray of shape ``(len(data), depth)``. Row ``i`` holds
        a 1 in column ``data[i]`` and 0 elsewhere; a label outside
        ``[0, depth)`` yields an all-zero row.
    """
    # Coerce first so plain Python sequences support the 2-D indexing below.
    data = np.asarray(data)
    # Broadcasting compares every label (as a column) with every class id.
    return (np.arange(depth) == data[:, None]).astype(np.int32)
def over_sample(y_origin, threshold):
    """Oversample labels so that rare classes reach about ``threshold`` items.

    Every sample of class ``c`` gets a sampling ratio of
    ``threshold / count(c)``, floored at 1, so samples of classes that
    already have at least ``threshold`` members are kept exactly once.
    The ratios are then turned into concrete row indices by
    ``ratio_sample``.

    Args:
        y_origin: 1-D array-like of integer labels in ``[0, NUM_LABELS)``.
        threshold: Target number of samples per class.

    Returns:
        ndarray of the resampled labels, in the order in which
        ``ratio_sample`` returns the selected indices (not the input order).
    """
    y_origin = np.asarray(y_origin)
    y = to_one_hot(y_origin, NUM_LABELS)
    y_counts = np.sum(y, axis=0)

    # A class with no samples would give threshold / 0 = inf, and inf * 0 is
    # nan, which would poison the row-wise maximum of *every* sample and make
    # the whole result empty. Give such classes a ratio of 0 instead.
    class_ratio = np.divide(
        threshold,
        y_counts,
        out=np.zeros(y_counts.shape, dtype=float),
        where=y_counts > 0,
    )

    # Each one-hot row selects the ratio of that sample's own class.
    sample_ratio = np.max(class_ratio * y, axis=1)
    # Never drop samples: a ratio below 1 is raised to 1.
    sample_ratio = np.maximum(sample_ratio, 1)

    index = ratio_sample(sample_ratio)
    # NOTE: the same ``index`` can be used to resample the matching features,
    # e.g. ``[x_token_train[i] for i in index]``.
    return y_origin[np.asarray(index, dtype=np.intp)]
def ratio_sample(ratio, rng=None):
    """Turn per-row sampling ratios into a list of row indices.

    A ratio such as ``2.7`` means the row is emitted twice for certain
    (integer part) and a third time with probability 0.7 (fractional part).

    Args:
        ratio: 1-D array-like of sampling ratios, one per row.
        rng: Optional random source exposing ``uniform(size=...)``, e.g.
            ``np.random.RandomState(seed)`` or ``np.random.default_rng(seed)``.
            Defaults to NumPy's global random state.

    Returns:
        A list of Python ints: first the rows selected by the random draw on
        the fractional part (ascending), then every row repeated
        ``floor(ratio)`` times (ascending).
    """
    ratio = np.asarray(ratio, dtype=float)
    if rng is None:
        # Global state keeps existing ``np.random.seed`` based callers working.
        rng = np.random

    sample_times = np.floor(ratio).astype(int)

    # Random draw for the fractional part of each ratio.
    fractional = ratio - sample_times
    draws = rng.uniform(size=fractional.shape)
    index = np.nonzero(fractional > draws)[0].tolist()

    # Fixed repeats for the integer part; negative counts mean "no copies".
    repeats = np.maximum(sample_times, 0)
    index.extend(np.repeat(np.arange(repeats.shape[0]), repeats).tolist())
    return index
# Demo: oversample the example labels with a per-class target of 24.
over_y = over_sample(y_origin, threshold=24)