PySpark study notes
Spark distributes data storage and processing across a cluster.
# Don't change this query
query = "FROM flights SELECT * LIMIT 10"
# Get the first 10 rows of flights
flights10 = spark.sql(query)
# Show the results
flights10.show()
Pandafy a Spark DataFrame
Converting a Spark DataFrame to pandas makes it easy to inspect locally.
# Don't change this query
query = "SELECT origin, dest, COUNT(*) as N FROM flights GROUP BY origin, dest"
# Run the query
flight_counts = spark.sql(query)
# Convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()
# Print the head of pd_counts
print(pd_counts.head())
<script.py> output:
  origin  ...    N
0    SEA  ...    8
1    SEA  ...   98
2    SEA  ...    2
3    SEA  ...  450
4    PDX  ...  144
[5 rows x 3 columns]
Reading files
# Don't change this file path
file_path = "/usr/local/share/datasets/airports.csv"
# Read in the airports data
airports = spark.read.csv(file_path, header=True)
# Show the data
airports.show()
Use the spark.table() method with the argument "flights" to create a DataFrame containing the values of the flights table in the .catalog. Save it as flights.
Show the head of flights using flights.show(). The column air_time contains the duration of the flight in minutes.
Update flights to include a new column called duration_hrs, that contains the duration of each flight in hours.
All of the operations below are performed on DataFrames.
# Create the DataFrame flights
flights = spark.table("flights")
# Show the head
flights.show()
# Add duration_hrs
flights = flights.withColumn("duration_hrs", flights.air_time/60)
Filtering Data
# Filter flights by passing a string
long_flights1 = flights.filter("distance > 1000")
# Filter flights by passing a column of boolean values
long_flights2 = flights.filter(flights.distance > 1000)
# Print the data to check they're equal
long_flights1.show()
long_flights2.show()
.show() displays the data.
Filtering on a condition works much like it does in pandas.
# Select the first set of columns
selected1 = flights.select("tailnum", "origin", "dest")
# Select the second set of columns
temp = flights.select(flights.origin, flights.dest, flights.carrier)
# Selecting columns by name is a lot like R
# Define first filter
filterA = flights.origin == "SEA"
# Define second filter
filterB = flights.dest == "PDX"
# Filter the data, first by filterA then by filterB
selected2 = temp.filter(filterA).filter(filterB)
alias()
selectExpr
Both can be used to rename columns.
selectExpr is essentially the same as using the expr() function inside select(): both are ways to build complex expressions. A few examples follow.
df.selectExpr("appid as newappid").show()
The line above selects the appid column and renames it to newappid.
df.selectExpr("count(distinct(appid)) as count1", "count(distinct(brand)) as count2").show()
The line above counts the distinct values of appid and the distinct values of brand.
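For comparison, the rename in the first example can also be written with alias() or with expr() inside select(); a minimal sketch using the same hypothetical df and appid column:

```python
from pyspark.sql.functions import expr

# Both lines are equivalent to df.selectExpr("appid as newappid")
df.select(df.appid.alias("newappid")).show()
df.select(expr("appid as newappid")).show()
```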
Aggregate functions
# Find the shortest flight from PDX in terms of distance
flights.filter(flights.origin == "PDX").groupBy().min("distance").show()
# Find the longest flight from SEA in terms of air time
flights.filter(flights.origin == "SEA").groupBy().max("air_time").show()
<script.py> output:
+-------------+
|min(distance)|
+-------------+
| 106|
+-------------+
+-------------+
|max(air_time)|
+-------------+
| 409|
+-------------+
# Average duration of Delta flights
flights.filter(flights.carrier == "DL").filter(flights.origin == "SEA").groupBy().avg("air_time").show()
# Total hours in the air
flights.withColumn("duration_hrs", flights.air_time/60).groupBy().sum("duration_hrs").show()
withColumn adds a new column to the DataFrame.
Create a DataFrame called by_plane that is grouped by the column tailnum.
Use the .count() method with no arguments to count the number of flights each plane made.
Create a DataFrame called by_origin that is grouped by the column origin.
Find the .avg() of the air_time column to find average duration of flights from PDX and SEA.
# Group by tailnum
by_plane = flights.groupBy("tailnum")
# Number of flights each plane made
by_plane.count().show()
# Group by origin
by_origin = flights.groupBy("origin")
# Average duration of flights from PDX and SEA
by_origin.avg("air_time").show()
You still need to import the module that provides the aggregate functions:
# Import pyspark.sql.functions as F
import pyspark.sql.functions as F
# Group by month and dest
by_month_dest = flights.groupBy("month", "dest")
# Average departure delay by month and destination
by_month_dest.avg("dep_delay").show()
# Standard deviation of departure delay
by_month_dest.agg(F.stddev("dep_delay")).show()
+-----+----+----------------------+
|month|dest|stddev_samp(dep_delay)|
+-----+----+----------------------+
| 11| TUS| 3.0550504633038935|
| 11| ANC| 18.604716401245316|
| 1| BUR| 15.22627576540667|
| 1| PDX| 5.677214918493858|
| 6| SBA| 2.380476142847617|
| 5| LAX| 13.36268698685904|
| 10| DTW| 5.639148871948674|
| 6| SIT| NaN|
| 10| DFW| 45.53019017606675|
| 3| FAI| 3.1144823004794873|
| 10| SEA| 18.70523227029577|
| 2| TUS| 14.468356276140469|
| 12| OGG| 82.64480404939947|
| 9| DFW| 21.728629347782924|
| 5| EWR| 42.41595968929191|
| 3| RDM| 2.16794833886788|
| 8| DCA| 9.946523680831074|
| 7| ATL| 22.767001039582183|
| 4| JFK| 8.156774303176903|
| 10| SNA| 13.726234873756304|
+-----+----+----------------------+
only showing top 20 rows
join: joining tables
# Examine the data
airports.show()
# Rename the faa column
airports = airports.withColumnRenamed("faa", "dest")
# Join the DataFrames
flights_with_airports = flights.join(airports, on="dest", how="leftouter")
# Examine the new DataFrame
flights_with_airports.show()
Machine Learning Pipelines
Joining DataFrames
# Rename year column
planes = planes.withColumnRenamed("year", "plane_year")
# Join the DataFrames
model_data = flights.join(planes, on="tailnum", how="leftouter")
cast
Converts a column's data type. A generic example of a join plus a computed column:
result = table1.join(table2, ['key_col'], "full").withColumn("ratio", col("col_a") / col("col_b"))
This adds a new column whose value is col("col_a") / col("col_b").
# To convert the type of a column using the .cast() method, you can write code like this:
dataframe = dataframe.withColumn("col", dataframe.col.cast("new_type"))
# Cast the columns to integers
model_data = model_data.withColumn("arr_delay", model_data.arr_delay.cast("integer"))
model_data = model_data.withColumn("air_time", model_data.air_time.cast("integer"))
model_data = model_data.withColumn("month", model_data.month.cast("integer"))
model_data = model_data.withColumn("plane_year", model_data.plane_year.cast("integer"))
# Create is_late
model_data = model_data.withColumn("is_late", model_data.arr_delay > 0)
# Convert to an integer
model_data = model_data.withColumn("label", model_data.is_late.cast("integer"))
# Remove missing values
model_data = model_data.filter("arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL")
from pyspark.ml.feature import StringIndexer, OneHotEncoder

# Create a StringIndexer for the string-valued carrier column
carr_indexer = StringIndexer(inputCol="carrier", outputCol="carrier_index")
# Create a OneHotEncoder (one-hot encoding)
carr_encoder = OneHotEncoder(inputCol="carrier_index", outputCol="carrier_fact")
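The pipeline below also references dest_indexer, dest_encoder, and vec_assembler, which never appear in these notes. A minimal sketch of how they might be defined, assuming the same indexer/encoder pattern for dest and a VectorAssembler over the prepared columns (the exact input column list is an assumption):

```python
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

# Index and one-hot encode the destination airport (same pattern as carrier)
dest_indexer = StringIndexer(inputCol="dest", outputCol="dest_index")
dest_encoder = OneHotEncoder(inputCol="dest_index", outputCol="dest_fact")

# Assemble numeric and encoded columns into a single feature vector
# (assumed column list -- adjust to whatever features the model should use)
vec_assembler = VectorAssembler(
    inputCols=["month", "air_time", "carrier_fact", "dest_fact", "plane_year"],
    outputCol="features",
)
```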
pipeline
# Import Pipeline
from pyspark.ml import Pipeline
# Make the pipeline
flights_pipe = Pipeline(stages=[dest_indexer, dest_encoder, carr_indexer, carr_encoder, vec_assembler])
fit_transform
# Fit and transform the data
piped_data = flights_pipe.fit(model_data).transform(model_data)
Splitting the dataset
# Split the data into training and test sets 类似于train_test_split
training, test = piped_data.randomSplit([.6, .4])
Logistic regression
# Import LogisticRegression
from pyspark.ml.classification import LogisticRegression
# Create a LogisticRegression Estimator
lr = LogisticRegression()
Evaluation metric
# Import the evaluation submodule
import pyspark.ml.evaluation as evals
# Create a BinaryClassificationEvaluator
evaluator = evals.BinaryClassificationEvaluator(metricName="areaUnderROC")
Make a grid
Grid search
# Import the tuning submodule (plus numpy for the hyperparameter range below)
import numpy as np
import pyspark.ml.tuning as tune
# Create the parameter grid
grid = tune.ParamGridBuilder()
# Add the hyperparameter
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))
grid = grid.addGrid(lr.elasticNetParam, [0, 1])
# Build the grid
grid = grid.build()
Cross-validation
# Create the CrossValidator
cv = tune.CrossValidator(estimator=lr,
                         estimatorParamMaps=grid,
                         evaluator=evaluator)
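best_lr used below is not defined in these notes; presumably it comes from fitting the CrossValidator on the training data and taking the best model. A minimal sketch under that assumption:

```python
# Fit the cross-validation models (trains one model per grid point per fold)
models = cv.fit(training)

# Keep the best-performing model found by the grid search
best_lr = models.bestModel
```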
Model evaluation
# Use the model to predict the test set
test_results = best_lr.transform(test)
# Evaluate the predictions
print(evaluator.evaluate(test_results))
drop
# Load the CSV file
aa_dfw_df = spark.read.format('csv').options(Header=True).load('AA_DFW_2018.csv.gz')
# Add the airport column using the F.lower() method
aa_dfw_df = aa_dfw_df.withColumn('airport', F.lower(aa_dfw_df['Destination Airport']))
# Drop the Destination Airport column
aa_dfw_df = aa_dfw_df.drop(aa_dfw_df['Destination Airport'])
# Show the DataFrame
aa_dfw_df.show()
<script.py> output:
+-----------------+-------------+-----------------------------+-------+
|Date (MM/DD/YYYY)|Flight Number|Actual elapsed time (Minutes)|airport|
+-----------------+-------------+-----------------------------+-------+
| 01/01/2018| 0005| 498| hnl|
| 01/01/2018| 0007| 501| ogg|
| 01/01/2018| 0043| 0| dtw|
| 01/01/2018| 0051| 100| stl|
| 01/01/2018| 0075| 147| dca|
| 01/01/2018| 0096| 92| stl|
| 01/01/2018| 0103| 227| sjc|
| 01/01/2018| 0119| 517| ogg|
| 01/01/2018| 0123| 489| hnl|
| 01/01/2018| 0128| 141| mco|
| 01/01/2018| 0132| 201| ewr|
| 01/01/2018| 0140| 215| sjc|
| 01/01/2018| 0174| 140| rdu|
| 01/01/2018| 0190| 68| sat|
| 01/01/2018| 0200| 215| sfo|
| 01/01/2018| 0209| 169| mia|
| 01/01/2018| 0217| 178| las|
| 01/01/2018| 0229| 534| koa|
| 01/01/2018| 0244| 115| cvg|
| 01/01/2018| 0262| 159| mia|
+-----------------+-------------+-----------------------------+-------+
only showing top 20 rows
Saving a DataFrame in Parquet format
Parquet is a compressed, columnar storage format.
# Combine the DataFrames into one
df3 = df1.union(df2)
# Save the df3 DataFrame in Parquet format
df3.write.parquet('AA_DFW_ALL.parquet', mode='overwrite')
# Read the Parquet file into a new DataFrame and run a count
print(spark.read.parquet('AA_DFW_ALL.parquet').count())
<script.py> output:
df1 Count: 139359
df2 Count: 119911
259270
createOrReplaceTempView
createOrReplaceTempView creates a temporary view: once the SparkSession that created it is closed, the view disappears, and it is not shared with other SparkSessions.
createGlobalTempView creates a global temporary view whose lifetime is the whole Spark application: as long as the application is running the view stays available, and it is shared across SparkSessions.
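A minimal sketch of both calls, using the flights DataFrame from earlier (the view names are illustrative):

```python
# Session-scoped temporary view: visible only in this SparkSession
flights.createOrReplaceTempView("flights_view")
spark.sql("SELECT COUNT(*) AS n FROM flights_view").show()

# Application-scoped global view: lives in the global_temp database
flights.createGlobalTempView("flights_global")
spark.sql("SELECT COUNT(*) AS n FROM global_temp.flights_global").show()
```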
filter
Show the distinct VOTER_NAME entries
voter_df.select(voter_df['VOTER_NAME']).distinct().show(40, truncate=False)
Filter voter_df where the VOTER_NAME is 1-20 characters in length
voter_df = voter_df.filter('length(VOTER_NAME) > 0 and length(VOTER_NAME) < 20')
Filter out voter_df where the VOTER_NAME contains an underscore
voter_df = voter_df.filter(~ F.col('VOTER_NAME').contains('_'))
Show the distinct VOTER_NAME entries again
voter_df.select('VOTER_NAME').distinct().show(40, truncate=False)
Column operations on DataFrames: .split(), .size(), .getItem().
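A short sketch of these three in combination, assuming voter_df has a VOTER_NAME column of space-separated names (this also produces the splits column used by the UDF example further down):

```python
import pyspark.sql.functions as F

# Split the name into an array column, then pull pieces out of it
voter_df = voter_df.withColumn('splits', F.split(voter_df.VOTER_NAME, r'\s+'))
voter_df = voter_df.withColumn('first_name', voter_df.splits.getItem(0))   # first element
voter_df = voter_df.withColumn('name_parts', F.size(voter_df.splits))      # array length
```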
withColumn
Similar to mutate in R: it adds a new column. The first argument is the new column's name, the second is the expression that produces its values.
# Add a column to voter_df for any voter with the title **Councilmember**
voter_df = voter_df.withColumn('random_val',
                               F.when(voter_df.TITLE == 'Councilmember', F.rand()))
# Show some of the DataFrame rows, noting whether the when clause worked
voter_df.show()
<script.py> output:
+----------+-------------+-------------------+-------------------+
| DATE| TITLE| VOTER_NAME| random_val|
+----------+-------------+-------------------+-------------------+
|02/08/2017|Councilmember| Jennifer S. Gates| 0.272243504035671|
|02/08/2017|Councilmember| Philip T. Kingston| 0.583853158237357|
|02/08/2017| Mayor|Michael S. Rawlings| null|
|02/08/2017|Councilmember| Adam Medrano|0.14137215411622484|
|02/08/2017|Councilmember| Casey Thomas| 0.9411379198657572|
|02/08/2017|Councilmember|Carolyn King Arnold| 0.8379601212162058|
|02/08/2017|Councilmember| Scott Griggs|0.18946575456658876|
|02/08/2017|Councilmember| B. Adam McGough| 0.4952465558048347|
|02/08/2017|Councilmember| Lee Kleinman| 0.8047711324429991|
|02/08/2017|Councilmember| Sandy Greyson| 0.719737910435363|
|02/08/2017|Councilmember| Jennifer S. Gates| 0.7558084225784659|
|02/08/2017|Councilmember| Philip T. Kingston| 0.7274572656632454|
|02/08/2017| Mayor|Michael S. Rawlings| null|
|02/08/2017|Councilmember| Adam Medrano| 0.68576094369742|
|02/08/2017|Councilmember| Casey Thomas|0.32803656527818037|
|02/08/2017|Councilmember|Carolyn King Arnold|0.24756136511724325|
|02/08/2017|Councilmember| Rickey D. Callahan| 0.8197696131561757|
|01/11/2017|Councilmember| Jennifer S. Gates| 0.2346485886453783|
|04/25/2018|Councilmember| Sandy Greyson| 0.4073895538345208|
|04/25/2018|Councilmember| Jennifer S. Gates| 0.1938743267054036|
+----------+-------------+-------------------+-------------------+
only showing top 20 rows
when/otherwise
# Add a column to voter_df for a voter based on their position
voter_df = voter_df.withColumn('random_val',
                               F.when(voter_df.TITLE == 'Councilmember', F.rand())
                                .when(voter_df.TITLE == 'Mayor', 2)
                                .otherwise(0))
# Show some of the DataFrame rows
voter_df.show()
# Use the .filter() clause with random_val
voter_df.filter(voter_df.random_val == 0).show()
when ... otherwise works like a conditional if/else: each row gets the value of the first condition it matches, or the otherwise value.
User-defined functions (UDFs)
Besides the built-in functions you can define your own; as in any language, a function simply packages up a piece of logic.
from pyspark.sql.types import StringType

def getFirstAndMiddle(names):
    # Return a space-separated string of names
    return ' '.join(names[:-1])

# Define the method as a UDF
udfFirstAndMiddle = F.udf(getFirstAndMiddle, StringType())
# Create a new column using your UDF
voter_df = voter_df.withColumn('first_and_middle_name', udfFirstAndMiddle(voter_df.splits))
# Show the DataFrame
voter_df.show()
Partitioning and lazy processing
# Select all the unique council voters
voter_df = df.select(df["VOTER NAME"]).distinct()  # distinct() removes duplicate rows
# Count the rows in voter_df
print("\nThere are %d rows in the voter_df DataFrame.\n" % voter_df.count())
# Add a ROW_ID
voter_df = voter_df.withColumn('ROW_ID', F.monotonically_increasing_id())
# Show the rows with 10 highest IDs in the set
voter_df.orderBy(voter_df.ROW_ID.desc()).show(10)
<script.py> output:
There are 36 rows in the voter_df DataFrame.
+--------------------+-------------+
| VOTER NAME| ROW_ID|
+--------------------+-------------+
| Lee Kleinman|1709396983808|
| the final 201...|1700807049217|
| Erik Wilson|1700807049216|
| the final 20...|1683627180032|
| Carolyn King Arnold|1632087572480|
| Rickey D. Callahan|1597727834112|
| the final 2...|1443109011456|
| Monica R. Alonzo|1382979469312|
| Lee M. Kleinman|1228360646656|
| Jennifer S. Gates|1194000908288|
+--------------------+-------------+
only showing top 10 rows
Showing the number of partitions
# Print the number of partitions in each DataFrame
print("\nThere are %d partitions in the voter_df DataFrame.\n" % voter_df.rdd.getNumPartitions()) # 展示分区数量的
print("\nThere are %d partitions in the voter_df_single DataFrame.\n" % voter_df_single.rdd.getNumPartitions())
# Add a ROW_ID field to each DataFrame
voter_df = voter_df.withColumn('ROW_ID', F.monotonically_increasing_id())
voter_df_single = voter_df_single.withColumn('ROW_ID', F.monotonically_increasing_id())
# Show the top 10 IDs in each DataFrame
voter_df.orderBy(voter_df.ROW_ID.desc()).show(10)
voter_df_single.orderBy(voter_df_single.ROW_ID.desc()).show(10)
<script.py> output:
There are 200 partitions in the voter_df DataFrame.
There are 1 partitions in the voter_df_single DataFrame.
+--------------------+-------------+
| VOTER NAME| ROW_ID|
+--------------------+-------------+
| Lee Kleinman|1709396983808|
| the final 201...|1700807049217|
| Erik Wilson|1700807049216|
| the final 20...|1683627180032|
| Carolyn King Arnold|1632087572480|
| Rickey D. Callahan|1597727834112|
| the final 2...|1443109011456|
| Monica R. Alonzo|1382979469312|
| Lee M. Kleinman|1228360646656|
| Jennifer S. Gates|1194000908288|
+--------------------+-------------+
only showing top 10 rows
+--------------------+------+
| VOTER NAME|ROW_ID|
+--------------------+------+
| Lee Kleinman| 35|
| the final 201...| 34|
| Erik Wilson| 33|
| the final 20...| 32|
| Carolyn King Arnold| 31|
| Rickey D. Callahan| 30|
| the final 2...| 29|
| Monica R. Alonzo| 28|
| Lee M. Kleinman| 27|
| Jennifer S. Gates| 26|
+--------------------+------+
only showing top 10 rows
.rdd
RDDs are Spark's core abstraction: fault-tolerant collections of elements that can be partitioned across the nodes of a cluster and processed in parallel with functional-style operations.
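A small illustration of dropping down to the RDD level, assuming a voter_df with a VOTER_NAME column as in the examples above:

```python
# .rdd exposes the DataFrame's underlying RDD of Row objects
row_rdd = voter_df.rdd
print(row_rdd.getNumPartitions())                          # how the data is split up
print(row_rdd.map(lambda row: row['VOTER_NAME']).take(5))  # functional-style parallel op
```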
cache
Caching
import time

start_time = time.time()
# Add caching to the unique rows in departures_df
departures_df = departures_df.distinct().cache()
# Count the unique rows in departures_df, noting how long the operation takes
print("Counting %d rows took %f seconds" % (departures_df.count(), time.time() - start_time)) #计算加载的时间
# Count the rows again, noting the variance in time of a cached DataFrame
start_time = time.time()
print("Counting %d rows again took %f seconds" % (departures_df.count(), time.time() - start_time))
You can check whether a DataFrame is cached, and remove it from the cache when it is no longer needed.
# Determine if departures_df is in the cache
print("Is departures_df cached?: %s" % departures_df.is_cached)
print("Removing departures_df from cache")
# Remove departures_df from the cache
departures_df.unpersist()
# Check the cache status again
print("Is departures_df cached?: %s" % departures_df.is_cached)
read config
# Name of the Spark application instance
app_name = spark.conf.get('spark.app.name')
# Driver TCP port
driver_tcp_port = spark.conf.get('spark.driver.port')
# Number of join partitions
num_partitions = spark.conf.get('spark.sql.shuffle.partitions')
# Show the results
print("Name: %s" % app_name)
print("Driver TCP port: %s" % driver_tcp_port)
print("Number of partitions: %s" % num_partitions)
<script.py> output:
Name: pyspark-shell
Driver TCP port: 45583
Number of partitions: 200
write config
# Store the number of partitions in variable
before = departures_df.rdd.getNumPartitions()
# Configure Spark to use 500 partitions
spark.conf.set('spark.sql.shuffle.partitions', 500)
# Recreate the DataFrame using the departures data file
departures_df = spark.read.csv('departures.txt.gz').distinct()
# Print the number of partitions for each instance
print("Partition count before change: %d" % before)
print("Partition count after change: %d" % departures_df.rdd.getNumPartitions())
join data
# Join the flights_df and airports_df DataFrames
normal_df = flights_df.join(airports_df, \
    flights_df["Destination Airport"] == airports_df["IATA"])
# Show the query plan
normal_df.explain()
The DataFrame API ships with a very large number of built-in functions (see the PySpark docs or the CSDN reference these notes originally linked).
Knowing the common ones saves a lot of effort.
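A few of those built-ins in one go, reusing the flights columns from earlier as an (assumed) example:

```python
import pyspark.sql.functions as F

flights.select(
    F.upper('origin').alias('origin_uc'),               # string function
    F.round(flights.air_time / 60, 2).alias('hrs'),      # math function
    F.concat_ws('-', 'origin', 'dest').alias('route'),   # combine two columns
).show(5)
```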
using broadcasting
# Import the broadcast method from pyspark.sql.functions
from pyspark.sql.functions import broadcast
# Join the flights_df and airports_df DataFrames using broadcasting
broadcast_df = flights_df.join(broadcast(airports_df), \
    flights_df["Destination Airport"] == airports_df["IATA"])
# Show the query plan and compare against the original
broadcast_df.explain()
<script.py> output:
== Physical Plan ==
*(2) BroadcastHashJoin [Destination Airport#12], [IATA#29], Inner, BuildRight
:- *(2) Project [Date (MM/DD/YYYY)#10, Flight Number#11, Destination Airport#12, Actual elapsed time (Minutes)#13]
: +- *(2) Filter isnotnull(Destination Airport#12)
: +- *(2) FileScan csv [Date (MM/DD/YYYY)#10,Flight Number#11,Destination Airport#12,Actual elapsed time (Minutes)#13] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/tmp/tmpvcaouj4c/AA_DFW_2018_Departures_Short.csv.gz], PartitionFilters: [], PushedFilters: [IsNotNull(Destination Airport)], ReadSchema: struct<Date (MM/DD/YYYY):string,Flight Number:string,Destination Airport:string,Actual elapsed ti...
+- BroadcastExchange HashedRelationBroadcastMode(List(input[1, string, true]))
+- *(1) Project [AIRPORTNAME#28, IATA#29]
+- *(1) Filter isnotnull(IATA#29)
+- *(1) FileScan csv [AIRPORTNAME#28,IATA#29] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/tmp/tmpvcaouj4c/airportnames.txt.gz], PartitionFilters: [], PushedFilters: [IsNotNull(IATA)], ReadSchema: struct<AIRPORTNAME:string,IATA:string>
Processing data as a pipeline
# Import the data to a DataFrame
departures_df = spark.read.csv('2015-departures.csv.gz', header=True)
# Remove any duration of 0
departures_df = departures_df.filter(departures_df[3] > 0)
# Add an ID column
departures_df = departures_df.withColumn('id', F.monotonically_increasing_id())
# Write the file out to JSON format
departures_df.write.json('output.json', mode='overwrite')
## Some data-processing tips
# Import the file to a DataFrame and perform a row count
annotations_df = spark.read.csv('annotations.csv.gz', sep='|')
full_count = annotations_df.count()
# Count the number of rows beginning with '#'
comment_count = annotations_df.where(F.col('_c0').startswith('#')).count()
# Import the file to a new DataFrame, without commented rows
no_comments_df = spark.read.csv('annotations.csv.gz', sep='|', comment='#')
# Count the new DataFrame and verify the difference is as expected
no_comments_count = no_comments_df.count()
print("Full count: %d\nComment count: %d\nRemaining count: %d" % (full_count, comment_count, no_comments_count))
<script.py> output:
Full count: 32794
Comment count: 1416
Remaining count: 31378
Removing invalid rows
# Split _c0 on the tab character and store the list in a variable
tmp_fields = F.split(annotations_df['_c0'], '\t')
# Create the colcount column on the DataFrame
annotations_df = annotations_df.withColumn('colcount', F.size(tmp_fields))
# Remove any rows containing fewer than 5 fields
annotations_df_filtered = annotations_df.filter(~ (annotations_df["colcount"] < 5))
# Count the number of rows
final_count = annotations_df_filtered.count()
print("Initial count: %d\nFinal count: %d" % (initial_count, final_count))
Splitting a column into multiple columns
Split the content of _c0 on the tab character (aka, '\t')
split_cols = F.split(annotations_df["_c0"], '\t')
Add the columns folder, filename, width, and height
split_df = annotations_df.withColumn('folder', split_cols.getItem(0))
split_df = split_df.withColumn('filename', split_cols.getItem(1))
split_df = split_df.withColumn('width', split_cols.getItem(2))
split_df = split_df.withColumn('height', split_cols.getItem(3))
Add split_cols as a column
split_df = split_df.withColumn('split_cols', split_cols)