导入所需要的包
(略)
解压加载进来的数据集
!unzip -qo data/data115112/ChineseStyle.zip -d data
label = paddle.to_tensor([1,0], dtype='int64')
one_hot_label =paddle.nn.functional.one_hot(label,num_classes=2)
print(one_hot_label )
制作数据集
#定义ChineseStyleDataset数据集类
class ChineseStyleDataset(Dataset):
#构造数据集和标签集
def __init__(self,transforms=None,train="train"):
super().__init__()
self.transforms=transforms
self.datas=list() #创建data列表成员,存放图像数据
self.labels=list() #创建labels列表成员,存放标签数据
#self.temps=list()
font_style=[("lishu",0),("xingkai",1)]
#遍历train文件夹下的xingkai图片
for font_tuple in font_style:
font_name=font_tuple[0]
font_val=font_tuple[1]
font_path="data/ChineseStyle/{}/{}".format(train,font_name)
for filename in os.listdir(font_path):
if ".ipynb_checkpoints" in filename: #图片文件内会自动包含.ipynb_checkpoints文件,需要排除掉
continue
img_path=os.path.join(font_path,filenam