Weight&Biases教程

简介

Weights & Biases 可以帮助跟踪你的机器学习项目。使用工具记录运行中的超参数和输出指标(Metric),然后对结果进行可视化和比较,并快速与同事分享你的发现。

好处

首先它和tensorboard的功能类似,但是区别在于可以通过真正的网址随时进行访问,同时他人也能看到你的效果,省去了麻烦的步骤,同时还能方便自己随时随地监工自己的网络运行情况,同时它的sweeps能自己调参优化模型
在这里插入图片描述

工具
仪表盘:跟踪实验、可视化结果。
报告:保存和分享可复制的成果/结论。
Sweeps:通过调节超参数来优化模型
Artifacts : 数据集和模型版本化,流水线跟踪。

网址

https://wandb.ai/site,注册一个账号,个人是免费的,然后创建一个项目,然后按照官网提示安装pip install wandb

测试demo

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import wandb


batch_size = 200
learning_rate = 0.01
epochs = 30

wandb.init(
    # Set the project where this run will be logged
    project=" ",#写自己的
    entity=" ",#写自己的
    # Track hyperparameters and run metadata
    config={
   
        "learning_rate": learning_rate,
        "architecture": "MLP",
        "dataset": "MNIST",
        "epochs": epochs,
    })

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)


class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x


device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)


global_step = 0

for epoch in range(epochs):
    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        global_step += 1

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            #展示数据
            wandb.log({
   "Train loss": loss.item()})
    net.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).float().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
<think>好的,我现在需要帮助用户解决在Python中遇到的`java.lang.ClassNotFoundException: org.mariadb.jdbc.Driver`和`Py4JJavaError`错误。首先,我要回顾用户之前的问题和我的回答,确保这次回答一致且深入。用户提到参考了站内引用,但里面的关键词主要是Java相关技术,可能与当前问题关联不大,所以需要专注于用户的实际需求。 首先,用户遇到的错误通常出现在使用PySpark连接MariaDB时,因为JVM找不到对应的JDBC驱动。在之前的回答中,我提到了下载驱动、配置路径以及验证的方法。现在需要详细检查可能的遗漏点,并提供更具体的解决方案。 接下来,用户提到的`Py4JJavaError`通常是因为PySpark在提交任务到JVM时出现异常,而根本原因还是驱动类未找到。需要确认用户是否正确配置了Spark会话的类路径,以及驱动版本是否兼容当前的MariaDB服务器版本。 可能的解决步骤包括: 1. 确保驱动JAR文件已正确下载,并且路径正确。 2. 在Spark配置中明确指定驱动路径。 3. 检查驱动版本是否与MariaDB服务器兼容。 4. 处理依赖冲突,例如其他JDBC驱动的存在可能引起问题。 5. 验证驱动是否成功加载,可以通过测试连接或打印版本号来确认。 此外,用户可能遇到的常见错误包括文件权限问题、路径错误或版本不匹配,需要逐一排查。还要考虑到用户可能在不同的环境中运行代码,比如本地开发环境或集群环境,配置方式可能有所不同。 最后,需要确保回答结构清晰,步骤明确,并提供验证方法,让用户能够自行确认问题是否解决。同时,生成的相关问题需要与当前主题紧密相关,帮助用户进一步理解可能遇到的类似问题。</think>### Python中解决`Py4JJavaError: org.mariadb.jdbc.Driver`类缺失的完整方案 #### 问题根源 该错误是**PySpark或其他Java交互场景**中因JVM未找到MariaDB JDBC驱动导致,常见于以下操作: 1. 使用`spark.read.jdbc()`读取MariaDB数据 2. 通过Py4J调用Java代码访问MariaDB 3. 未正确配置Spark的依赖加载路径[^1] --- #### 分步解决方案 ##### 1. **获取正确版本的MariaDB驱动** - **官方下载**:[MariaDB JDBC驱动下载页](https://mariadb.com/downloads/connectors/connectors-data-access/java8-connector) - **版本选择**: - MariaDB 10.6+ → 使用驱动`3.x` - MariaDB 10.5及以下 → 使用驱动`2.7.x` - *避免使用MySQL驱动替代*(易导致协议不兼容) ##### 2. **配置PySpark的JAR加载路径** 在初始化SparkSession时显式指定驱动路径: ```python from pyspark.sql import SparkSession spark = SparkSession.builder \ .appName("MariaDB_Example") \ .config("spark.driver.extraClassPath", "/绝对路径/mariadb-java-client-3.3.2.jar") \ .config("spark.executor.extraClassPath", "/绝对路径/mariadb-java-client-3.3.2.jar") \ .getOrCreate() ``` *关键点*:`spark.driver`和`spark.executor`需同时配置(集群模式必填) ##### 3. **验证驱动加载(代码级检查)** ```python # 方法1:通过JDBC连接测试 url = "jdbc:mariadb://your_host:3306/your_db" properties = { "user": "your_user", "password": "your_password", "driver": "org.mariadb.jdbc.Driver" } try: df = spark.read.jdbc(url=url, table="your_table", properties=properties) df.show(2) except Exception as e: print("驱动加载失败:", str(e)) # 方法2:直接调用Java类(需安装jpype1) import jpype jpype.startJVM(classpath=["/绝对路径/mariadb-java-client-3.3.2.jar"]) print(jpype.JClass("org.mariadb.jdbc.Driver").getVersion()) ``` ##### 4. **依赖冲突排查** 若仍报错,检查是否有其他JDBC驱动干扰: ```bash # 查看Spark已加载的JAR列表 find $SPARK_HOME/jars -name "*jdbc*.jar" # 移除冲突的JAR(如mysql-connector-java) mv $SPARK_HOME/jars/mysql-connector-java-*.jar /tmp/ ``` --- #### 高级场景处理 | 场景 | 解决方案 | |--------------------------|------------------------------------------------------------------------| | **集群部署** | 将JAR上传到HDFS,配置`spark.jars hdfs:///path/to/mariadb-java-client.jar` | | **Jupyter Notebook环境** | 在启动时添加参数:`pyspark --driver-class-path /path/to/mariadb-java-client.jar` | | **Docker容器** | 在Dockerfile中添加:`COPY mariadb-java-client-3.3.2.jar /opt/spark/jars/` | --- #### 错误案例解析 **案例现象**: `Py4JJavaError: An error occurred while calling o36.jdbc. : java.lang.ClassNotFoundException: org.mariadb.jdbc.Driver` **排查步骤**: 1. 检查`spark-defaults.conf`是否包含`spark.jars`配置 2. 在Linux终端执行:`jar tf mariadb-java-client-3.3.2.jar | grep Driver.class` *应输出:`org/mariadb/jdbc/Driver.class`* 3. 检查JAR文件权限:`chmod 644 mariadb-java-client-3.3.2.jar` --- [^1]: Spark的类加载机制遵循双亲委派模型,用户自定义JAR需通过`extraClassPath`显式声明
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值