<think>我们遇到一个错误:`DataType longtype is not supported.(line 1, pos 0)`。这个错误发生在代码的第24行,也就是在转换`machine_record_id`列为LongType的时候:
```scala
val trainDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data WHERE machine_record_state IS NOT NULL")
.withColumn("machine_record_id", $"machine_record_id".cast("LongType"))
```
错误的原因是Spark SQL不支持字符串"LongType",它需要我们使用`DataTypes`中定义的类型或者使用字符串的别名。正确的做法是使用字符串"long"而不是"LongType"。
在Spark中,我们可以使用以下方式指定类型:
- 使用字符串别名,例如:"long", "string", "double"等。
- 使用`org.apache.spark.sql.types`包中定义的类型,例如:`LongType`。
因此,我们有两种修改方式:
1. 使用字符串别名"long":
```scala
.withColumn("machine_record_id", $"machine_record_id".cast("long"))
```
2. 导入`LongType`并使用它(推荐,因为它是类型安全的):
```scala
import org.apache.spark.sql.types.LongType
...
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
```
我们选择第二种方式,因为它在编译时就能检查类型是否正确。
另外,我们注意到在代码中同时使用了`import spark.implicits._`,这已经包含了`$`的隐式转换,所以我们可以直接使用`$`符号。
修改后的代码片段如下:
```scala
import org.apache.spark.sql.types.LongType
// ...
val trainDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data WHERE machine_record_state IS NOT NULL")
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
```
同样,对testDF也需要做同样的修改。
此外,我们还注意到日志中有关于Hive的jar文件路径不存在的警告。这些警告可能不会影响程序运行,但为了确保Hive支持正常工作,我们需要确保Spark配置正确的Hive依赖。通常,我们可以通过设置`spark.sql.warehouse.dir`和`hive.metastore.uris`来连接Hive元数据服务。在创建SparkSession时,我们已经调用了`.enableHiveSupport()`,这是正确的。
另外,日志中还有log4j的警告,我们可以通过提供一个log4j.properties文件来配置日志输出。不过,这不是导致程序失败的原因,我们可以暂时忽略。
因此,我们主要修改数据类型转换的部分。
修改后的完整代码:
</think>您的错误是由于在 `cast` 方法中使用了字符串 `"LongType"` 而不是 Spark SQL 的 `LongType` 对象导致的。以下是修复方案:
### 问题分析
```scala
// 错误用法
.withColumn("machine_record_id", $"machine_record_id".cast("LongType"))
// 正确用法
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
```
在 Spark 中,`cast()` 方法需要传入 `DataType` 对象,而不是字符串名称。
### 完整修复后的代码
```scala
package gs8
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.types.{LongType, DoubleType} // 导入LongType
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.sql.functions.{col, when}
import java.util.Properties
object shujuwaqu2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("RandomForestModel")
.master("local[*]")
.enableHiveSupport()
.getOrCreate()
Logger.getLogger("org").setLevel(Level.ERROR)
import spark.implicits._
// 使用LongType对象而不是字符串"LongType"
val trainDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data WHERE machine_record_state IS NOT NULL")
.withColumn("machine_record_id", $"machine_record_id".cast(LongType)) // 修复1: 使用LongType对象
val featureCols = Array(
"machine_record_mainshaft_speed",
"machine_record_mainshaft_multiplerate",
"machine_record_mainshaft_load",
"machine_record_feed_speed",
"machine_record_feed_multiplerate",
"machine_record_pmc_code",
"machine_record_circle_time",
"machine_record_run_time",
"machine_record_effective_shaft",
"machine_record_amount_process",
"machine_record_use_memory",
"machine_record_free_memory",
"machine_record_amount_use_code",
"machine_record_amount_free_code"
)
val assembler = new VectorAssembler()
.setInputCols(featureCols)
.setOutputCol("features")
val trainData = assembler.transform(trainDF)
.select($"machine_record_state".alias("label"), $"features")
val rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setNumTrees(20)
.setMaxDepth(8)
val pipeline = new Pipeline().setStages(Array(rf))
val model = pipeline.fit(trainData)
println("随机森林模型训练完成!")
// 同样修复这里
val testDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data_test")
.withColumn("machine_record_id", $"machine_record_id".cast(LongType)) // 修复2: 使用LongType对象
val testData = assembler.transform(testDF)
.select($"machine_record_id", $"features")
val predictions = model.transform(testData)
.select($"machine_record_id", $"prediction".alias("machine_record_state"))
println("预测完成,前5条结果:")
predictions.show(5, truncate = false)
val url = "jdbc:mysql://xueai:3306/shtd_industry?useSSL=false&serverTimezone=UTC"
val props = new Properties()
props.put("user", "root")
props.put("password", "admin")
props.put("driver", "com.mysql.jdbc.Driver")
predictions.write
.mode("overwrite")
.jdbc(url, "ml_result", props)
println("预测结果已成功写入 MySQL 表 ml_result!")
println("Hive 中 dwd.fact_machine_learning_data_test 预测完毕。")
println("请在 MySQL 中执行以下查询语句查看结果:")
println(
"""
|SELECT * FROM ml_result
|WHERE machine_record_id IN (1,8,20,28,36);
|""".stripMargin)
spark.stop()
}
}
```
### 关键修复点解释:
1. **正确导入 LongType**:
```scala
import org.apache.spark.sql.types.{LongType, DoubleType}
```
从 `spark.sql.types` 导入必要的类型对象
2. **使用类型对象而非字符串**:
```scala
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
```
将 `"LongType"` 字符串替换为 `LongType` 对象
3. **修复两个地方**:
在训练数据和测试数据的转换中都使用了正确的类型对象
### 其他优化建议:
1. **添加空值处理**:
```scala
val trainDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data WHERE machine_record_state IS NOT NULL")
.na.fill(0, featureCols) // 填充数值型特征的空值
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
```
2. **MySQL 连接池配置**:
```scala
props.put("rewriteBatchedStatements", "true") // 启用批量写入
props.put("batchsize", "50000") // 设置批量大小
```
3. **预测结果格式化**:
```scala
val predictions = model.transform(testData)
.select(
format_string("%d", $"machine_record_id").as("machine_record_id"), // 确保整数格式
$"prediction".cast(IntegerType).alias("machine_record_state") // 转换为整数
)
```
### 完整优化版本:
```scala
package gs8
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.types.{LongType, IntegerType}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.sql.functions.{col, format_string}
import java.util.Properties
object shujuwaqu2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("RandomForestModel")
.master("local[*]")
.enableHiveSupport()
.getOrCreate()
Logger.getLogger("org").setLevel(Level.ERROR)
import spark.implicits._
val featureCols = Array(
"machine_record_mainshaft_speed",
"machine_record_mainshaft_multiplerate",
"machine_record_mainshaft_load",
"machine_record_feed_speed",
"machine_record_feed_multiplerate",
"machine_record_pmc_code",
"machine_record_circle_time",
"machine_record_run_time",
"machine_record_effective_shaft",
"machine_record_amount_process",
"machine_record_use_memory",
"machine_record_free_memory",
"machine_record_amount_use_code",
"machine_record_amount_free_code"
)
// 优化:空值填充 + ID 类型转换
val trainDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data WHERE machine_record_state IS NOT NULL")
.na.fill(0, featureCols)
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
val assembler = new VectorAssembler()
.setInputCols(featureCols)
.setOutputCol("features")
val trainData = assembler.transform(trainDF)
.select($"machine_record_state".alias("label"), $"features")
val rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setNumTrees(50) // 增加树的数量提高精度
.setMaxDepth(10) // 增加深度捕捉复杂关系
.setSubsamplingRate(0.8)
val pipeline = new Pipeline().setStages(Array(rf))
val model = pipeline.fit(trainData)
println("随机森林模型训练完成!")
// 同样处理测试数据
val testDF = spark.sql("SELECT * FROM dwd.fact_machine_learning_data_test")
.na.fill(0, featureCols)
.withColumn("machine_record_id", $"machine_record_id".cast(LongType))
val testData = assembler.transform(testDF)
.select($"machine_record_id", $"features")
// 优化:格式化ID并转换预测值为整数
val predictions = model.transform(testData)
.select(
format_string("%d", $"machine_record_id").as("machine_record_id"),
$"prediction".cast(IntegerType).alias("machine_record_state")
)
println("预测完成,前5条结果:")
predictions.show(5, truncate = false)
val url = "jdbc:mysql://xueai:3306/shtd_industry?useSSL=false&serverTimezone=UTC"
val props = new Properties()
props.put("user", "root")
props.put("password", "admin")
props.put("driver", "com.mysql.jdbc.Driver")
props.put("rewriteBatchedStatements", "true") // 批量写入优化
props.put("batchsize", "50000") // 批量大小
// 添加写入异常处理
try {
predictions.write
.mode("overwrite")
.jdbc(url, "ml_result", props)
println("预测结果已成功写入 MySQL 表 ml_result!")
// 验证写入结果
val readBack = spark.read.jdbc(url, "ml_result", props)
println(s"写入记录数: ${readBack.count()}")
} catch {
case e: Exception =>
println(s"写入MySQL失败: ${e.getMessage}")
e.printStackTrace()
}
println("Hive 中 dwd.fact_machine_learning_data_test 预测完毕。")
println("请在 MySQL 中执行以下查询语句查看结果:")
println("SELECT * FROM ml_result LIMIT 5;")
spark.stop()
}
}
```