import findspark
findspark.init()
from pyspark import SparkContext
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, max
from pyspark.sql.types import StringType, StructField, StructType,IntegerType,FloatType
import json
import csv
import os
from pyspark.sql import functions as F
# Spark bootstrap: a local-mode SparkContext for RDD work, plus a
# SparkSession (reusing the same context) for the DataFrame/SQL API.
sc = SparkContext('local', 'spark_project')
# Suppress INFO-level log noise from Spark's internals.
sc.setLogLevel('WARN')
spark = SparkSession.builder.getOrCreate()
# Build the DataFrame schema: Country is a string, Year an integer, and
# every remaining column (emission figures) a float. All fields nullable.
# NOTE(review): the original loop body was unindented (a SyntaxError as
# written) — restored here; logic is otherwise unchanged.
schemaString = "Country,Year,Total,Coal,Oil,Gas,Cement,Flaring,Other,PerCapita"
fields = []
for field in schemaString.split(","):
    if field == 'Country':
        a = StructField(field, StringType(), True)
    elif field == 'Year':
        a = StructField(field, IntegerType(), True)
    else:
        a = StructField(field, FloatType(), True)
    fields.append(a)
schema = StructType(fields)
def _parse_record(line):
    """Split one tab-separated line into a Row matching `schema`:
    Country (str), Year (int), then eight float emission columns."""
    p = line.split("\t")
    return Row(p[0], int(p[1]), *(float(v) for v in p[2:10]))

# Load the raw text file, parse each record, and wrap it in a DataFrame.
CO2Rdd1 = sc.textFile('/user/xmw/dataset.csv')
CO2Rdd = CO2Rdd1.map(_parse_record)
mdf = spark.createDataFrame(CO2Rdd, schema)
# Register as a temp view so later blocks can run SQL against it.
mdf.createOrReplaceTempView("usInfo")
mdf.show()
# Display the table structure.
mdf.printSchema()
# 1. Query China's total emissions for 2021.
# Year is an IntegerType column, so compare against an integer literal
# (the original compared against the string '2021', which only worked via
# Spark's implicit cast).
df1 = spark.sql(
    "SELECT * FROM usInfo WHERE Country = 'China' AND Year = 2021")
# Persist the result as a single CSV file.
df1.repartition(1).write.csv('/user/xmw/df1', mode="overwrite")
df1.show()
# 2. Compute the grand total of the `Total` column across all rows.
df2 = mdf.agg(F.sum("Total").alias("total_energy"))
df2.show()
# Persist the single-row result as one CSV file.
df2.repartition(1).write.csv('/user/xmw/df2', mode="overwrite")
# 3. Average per-capita usage for each country, across all years.
df3 = mdf.groupBy("Country").agg(F.avg("PerCapita").alias("AvgPerCapita"))
df3.show()
# Persist one CSV per the whole (repartitioned) result.
df3.repartition(1).write.csv('/user/xmw/df3', mode="overwrite")
# 4. Find the country and year with the highest total emissions.
df4 = mdf.select("Country", "Year", "Total").orderBy(F.col("Total").desc())
df4.show(1)
df4.repartition(1).write.csv('/user/xmw/df4', mode="overwrite")
# Fetch the top row once: each .first() call triggers a full Spark job,
# so the original's two calls ran the sort twice.
top_row = df4.first()
print("排放总量最高的国家:", top_row["Country"])
print("对应的年份:", top_row["Year"])