Databricks

Spark を使ってみる

公開日: 2023-10-23 更新日: 2024-09-03

アーキテクチャ

ノード

  • Driver … 分散しない処理とExecutorの制御を行う
  • Executor … 分散処理を行うノード

処理の構造

  • ジョブ … 複数のステージを内包。処理するファイルの文だけ作られる
  • ステージ … 複数のタスクを内包。単一ノード内で完結するタスクをまとめたもの
  • タスク

データを入出力する

  • RDD … Spark がデータを扱う際の素のAPI.
  • DataFrame … pandas に似た API が提供されており扱いやすい.
# -----------------------------
# RDD
# SparkContext を通して扱う
# -----------------------------
def proc(val):
    return val * 2

dt_lst = [1,2,3,"aaa"]
dt_rdd = spark.sparkContext.parallelize(dt_lst)
dt_rdd.map(proc).collect()

# -----------------------------
# DataFrame 
# SparkSession を通して扱う。
# このオブジェクトは RDD/Hive/SQL を束ねて扱うための仕組みとして提供されている。
# -----------------------------

# 関数から生成
dt_df = spark.range(0, 5)

# 配列から生成
dt_df = spark.createDataFrame([[1,2,3]], ["col1", "col2", "col3"])

# フィールドの情報を取得
print(df.schema)
df.printSchema()

# 取得した DataFrame の内容を表示
dt_df.display()
display(dt_df)

読み書き時 (基本)

SparkSQL での書き方

-- ファイルパスを指定して永続テーブル化する (SQL)
DROP TABLE IF EXISTS diamonds;
CREATE TABLE diamonds
USING CSV OPTIONS (path "/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv", header "true");

-- ファイルパスを指定して一時テーブル化する (SQL)
DROP TABLE IF EXISTS diamonds_tmp;
CREATE TEMP TABLE diamonds_tmp
USING CSV OPTIONS (path "/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv", header "true");

-- ファイルパスを指定してグローバル一時ビュー化する (SQL)
DROP VIEW IF EXISTS diamonds_tmp;
CREATE GLOBAL TEMP VIEW diamonds_tmp
USING CSV OPTIONS (path "/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv", header "true");
SELECT * FROM global_temp.diamonds_tmp;

-- 指定したテーブルへデータロード
COPY INTO diamonds
FROM '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' FILEFORMAT = CSV
FORMAT_OPTIONS ('inferSchema'='true', 'header'='true')

-- ファイパスを指定してロードする
SELECT *
FROM read_files(
    'dbfs:/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv',
    format => 'csv', header => true)
LIMIT 10

-- format.`path` でもロードできる
SELECT * FROM binaryFile.`path\`
SELECT * FROM json.`path\`
SELECT * FROM text.`path\`
SELECT * FROM csv.`path\`

Python での書き方

from pyspark.sql.types import *

# テーブルから生成
dt_df = spark.table("samples.tpch.customer")

# ファイルパスを指定して読み込む (Python)
df = spark.read \
    .options(header=True, inferSchema=True) \
    .csv("/databricks-datasets/retail-org/customers/")

df = spark.read \
    .format("csv") \
    .options(header=True, inferSchema=True) \
    .load("/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv") \
    .limit(10)

# 一時テーブルを生成
df.createOrReplaceTempView("diamonds_tmp")
# 永続テーブルを生成
df.write.saveAsTable("diamonds")

# 構造化ストリーミングで読み込み
# ストリーミングで読み込む場合にはスキーマ (後述) の指定が必要
df = spark.readStream \
    .format("csv") \
    .options(header=True) \
    .schema(
        StructType([
            StructField('customer_id', IntegerType(), True),
            StructField('tax_id', DoubleType(), True),
            StructField('tax_code', StringType(), True),
            StructField('customer_name', StringType(), True),
            StructField('state', StringType(), True),
            StructField('city', StringType(), True),
            StructField('postcode', StringType(), True),
            StructField('street', StringType(), True),
            StructField('number', StringType(), True),
            StructField('unit', StringType(), True),
            StructField('region', StringType(), True),
            StructField('district', StringType(), True),
            StructField('lon', DoubleType(), True),
            StructField('lat', DoubleType(), True),
            StructField('ship_to_address', StringType(), True),
            StructField('valid_from', IntegerType(), True),
            StructField('valid_to', DoubleType(), True),
            StructField('units_purchased', DoubleType(), True),
            StructField('loyalty_segment', IntegerType(), True)
        ])
    ) \
    .load("/databricks-datasets/retail-org/customers/")
display(df)

読み書き時 (色々)

色々なオプションが提供されている。

PropDescDefault
sep区切り文字.カンマ
encoding文字コードUTF-8
header先頭行から項目名を取得するかFalse
inferSchema1度読み込んでスキーマを推測した後に改めて読み込む. 推測しない場合は全て文字列型になる. date型は未サポートFalse

SQL での書き方。

CREATE OR REPLACE TABLE sample_tbl AS
SELECT * FROM (VALUES (1, "aaa"), (2, "bbb"), (3, "ccc")) AS (col_a, col_b);
CREATE OR REPLACE TEMP VIEW sample_new AS
SELECT * FROM (VALUES (3, "ddd"), (4, "eee")) AS (col_a, col_b);

-- 上書き
-- 新規やスキーマ変更は不可
INSERT OVERWRITE sample_tbl SELECT * FROM sample_new;
-- 新規やスキーマ変更も可能
CREATE OR REPLACE TABLE sample_tbl AS SELECT * FROM sample_new;

-- マージ
MERGE INTO sample_tbl a USING sample_new b ON a.col_a = b.col_a
WHEN MATCHED AND a.col_b != b.col_b THEN UPDATE SET a.col_b = b.col_b
WHEN NOT MATCHED THEN INSERT *

-- テーブルクローン  
-- シャローは複製元のファイルを参照する一方、ディープは全データをコピーする。
CREATE TABLE sample_tbl_sc SHALLOW CLONE sample_tbl;
CREATE TABLE sample_tbl_dc DEEP CLONE sample_tbl;

-- Generated Column
-- 値を自動で導出する
CREATE OR REPLACE TABLE sample_tbl (
  col_a INTEGER,
  col_b INTEGER GENERATED ALWAYS AS (col_a *2)
);
INSERT INTO sample_tbl(col_a) VALUES (1), (2)
SELECT * FROM sample_tbl

Python での書き方。

# 参考: https://spark.apache.org/docs/latest/sql-data-sources-csv.html
spark.read \
    .option("inferSchema", True) \
    .options(inferSchema=True) \
    .csv(path, inferSchema=True)

# 参考: https://spark.apache.org/docs/latest/sql-data-sources-parquet.html
# オプションの compression は parquet で書き出す際に使用できるオプション。  
spark.write \
    .option("compression", "snappy") \
    .mode("append|overwrite|error|ignore") \
    .parquet(path)

スキーマ

DataFrame は項目定義の情報としてスキーマを要する。
これはソースデータから推定する形で生成されるが、明示も可能。ストリーミングを使う際など推定が出来ない場合では明示することになる。

from pyspark.sql.types import *

# 配列から生成 (スキーマ明示)
dt_df = spark.createDataFrame(
    [
        (1, {"prop1": 111, "prop2": "hoge1"}),
        (2, {"prop1": 222, "prop2": "hoge2"}),
        (3, {"prop1": 333, "prop2": "hoge3"}),
    ],
    StructType([
        StructField("RecordID", IntegerType(), False),
        StructField(
            "Info",
            StructType([
                StructField("prop1", IntegerType(), True),
                StructField("prop2", StringType(), True)
            ])
        )
    ])
)

スキーマは変更できる (スキーマ進化、スキーマ上書き)。

spark.write \
    .option("compression", "snappy") \
    # スキーマを変更
    .option("mergeSchema|overwriteSchema", "true")
    .mode("append|overwrite|error|ignore") \
    .parquet(path)

スキーマ周りの tips

# DataFrame のスキーマ情報からスキーマ定義に使用できる項目定義テキストを取得
events_df._jdf.schema().toDDL()

# 項目定義テキストから型オブジェクトを生成
pyspark.sql.types._parse_datatype_string(string)

テーブル情報の取得

テーブルやビューの情報を見る際は DESCRIBE を使う。 メタデータを見たい場合は EXTENDED を付与する。保存場所が見たい場合は DETAIL を付与する

CREATE TEMP VIEW sample_vw AS (SELECT * FROM (VALUES (1,2,3), (1,2,3)) AS (col_a, col_b, col_c));
DESCRIBE EXTENDED SAMPLE_VW;
DESCRIBE DETAIL SAMPLE_VW;

-- Databricks で Delta Lake を使用していると使用可能
DESCRIBE HISTORY SAMPLE_VW;

-- 各種オブジェクトの一覧を表示
SHOW TABLES;
SHOW SCHEMAS;
SHOW CATALOGS;
SHOW CREATE TABLE tablename;

テーブルのプロパティとオプション

-- ビューやテーブルにカスタムプロパティを付与する
CREATE OR REPLACE VIEW sample_vw TBLPROPERTIES ('prop_1'=123, 'prop_2'=234) AS (
    SELECT * FROM (VALUES (1,2,3), (1,2,3)) AS (col_a, col_b, col_c)
);
-- ビューやテーブルにオプションを付与する
-- https://learn.microsoft.com/ja-jp/azure/databricks/sql/language-manual/sql-ref-syntax-ddl-tblproperties
-- 一般的なオプション
-- * delta.appendOnly
-- * delta.dataSkippingNumIndexedCols
-- * delta.deletedFileRetentionDuration
-- * delta.logRetentionDuration
CREATE OR REPLACE TABLE sample_tbl(col_a integer, col_b integer) OPTIONS (prop_1=123, prop_2=234);

ALTER VIEW sample_vw SET TBLPROPERTIES ('prop_3'='abc');
ALTER VIEW sample_vw SET TBLPROPERTIES ('prop_4'='bcd');
ALTER VIEW sample_vw UNSET TBLPROPERTIES ('prop_4');

SHOW TBLPROPERTIES sample_vw

テーブルメンテナンス

OPTIMIZE sample_tbl ZORDER BY col_a
VACUUM sample_tbl [RETAIN 30 HOURS]

データをクエリする

クエリ作成

SQL で使用するワードを反映したメソッドを呼び出すことで集計できる。 複雑な条件を定義したい場合は項目に対し pyspark.sql.Column クラスのインスタンスを生成し、クラスのメソッドを使うことで定義できる。

from pyspark.sql.functions import col, when, otherwise

spark.table("products") \
    .where(col("price") < 200) \
    .orderBy(col("price").desc(), "name") \
    .select(
        "name",
        col("price").cast("string"),
        when(col("price") > 10, 1).otherwise(2)
    ) \
    .display()

# 項目の取得には幾つかの方法がある。
df.columnName
df["columnName"]
col("columnName")

Column クラスを用いる事で定義できる集計式は以下の通り。以下に幾つか列挙する。
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/column.html

MethodDescription
*, + , <, >=, ==, !=等価・比較演算子
alias列にエイリアスを与えます
cast, astype列を異なるデータ型にキャスト
isNull, isNotNull, isNannull、非null、NaNの判断
asc, desc列の昇順/降順

Column クラスの生成は pyspark.sql.functions でメソッドが多く提供されている。
CASE 文などもこれを用いて表現する。
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html

from pyspark.sql.functions import col, when

df.withColumn("hoge",
    when(col("foo") < 100, "value1").
    when(col("foo") < 200, "value2").
    otherwise("value3")
)
MethodDescription
ceil指定された列を正の値の場合は切上げ、負の値の場合は切り下げする
cos指定された値の余弦を計算する
log指定された値の自然対数を計算する
roundHALF_UP丸めモードで0桁までの列eの値を返する
sqrt指定された浮動小数点値の平方根を計算する

グループ化したい場合は以下のようにする。

import pyspark.sql.functions as F

# グループ化
df.groupBy("columnName")
df.groupBy("columnName1", "columnName2")
df.groupBy(["columnName1", "columnName2"])

# グループ集計
df.groupBy("columnName").sum("sumColumn1", "sumColumn2")
df.groupBy("columnName").agg(
    F.sum("sumColumn1").alias("sumColumn1"),
    F.avg("avgColumn2").alias("avgColumn2")
)

日付の扱いも pyspark.sql.functions に便利なメソッドが提供されている。

df = spark \
    .createDataFrame([("2023-11-05 21:00:00",)], ["date_text"]) \
    .select(
        F.to_utc_timestamp(F.to_timestamp("date_text", "yyyy-MM-dd HH:mm:ss"), "Asia/Tokyo").alias("date_dt"),
        F.dayofweek(F.current_timestamp()),
        F.unix_timestamp(F.col("date_text"), "yyyy-MM-dd HH:mm:ss").alias("unixtime")
    ) \
    .display()
MethodDescription
add_months(col: Column, months: int): Columnmonths ヶ月足す
current_timestamp(): Column現在の日時
date_format(col: Column, format: str): Column日付/タイムスタンプ/文字列を文字に変換
dayofweek(col: Column)日曜~月曜を1~7の数値に変換
from_unixtime(col: Column, format: str)Unixタイムを format 形式の文字列へ変換
minute(col: Column)分を数値として抽出
unix_timestamp(col: Column, format: str)format 形式で表現された日付を Unix タイムに変換

pyspark.sql.functions には他にも便利な関数がある。

MethodDescription
approx_count_distinct一意のアイテムの近似数を返す
avg平均値を返す
collect_list重複を含むオブジェクトのリストを返す
corr2つの列のピアソン相関係数を返す
max最大値を計算する
mean平均値を計算する
stddev_sampサンプル標準偏差を返す
sumDistinct式内の一意の値の合計を返す
var_pop母集団分散を返す
split項目の値を引数の値で分割する

項目を新たに作ったり、名前を変更する場合は以下の通り。

# 項目作成
dt_df.withColumn("columnName", col: Column)
dt_df.withColumnRename("oldColumnName", "newColumnName")

# 項目削除
dt_df.drop("columnName")

# データフレーム統合
# union all 相当. 統合対象のスキーマは順序レベルで同一にする必要があるが、ByName の場合は項目名で突合される
dt_df.union(df_df2)
dt_df.unionByName(df_df2)

欠落データに対するメソッドをまとめたものとして pyspark.sql.DataFrameNaFunctions もある。

# null 行の削除
dt_df.na.drop(
    how: str = 'any(default)|all',
    thresh: Optional[int] = None,
    subset: Union[str, List[str], None] = None
)
# null 列へ代替値を代入
dt_df.na.fill(value, subset: Optional[List[str]])
# 任意の値へ代替値を代入. dt_df.replace のエイリアス
dt_df.na.replace(newvalue, oldvalue, subset: Optional[List[str]])

テーブル結合の操作は以下の通り。

joined_df = dt_df.join(other=dt_df2, on='columnName', how = "inner")

ネストデータの扱いは以下の通り。

import pyspark.sql.types as T
import pyspark.sql.functions as F

# リストオブジェクトを定義したデータフレームを定義
dt_df = spark.createDataFrame(
    [
        (1, [{"prop1": 111, "prop2": "hoge1"}, {"prop1": 222, "prop2": "hoge2"}]),
        (2, [{"prop1": 444, "prop2": "hoge1"}, {"prop1": 555, "prop2": "hoge2"}]),
        (3, [{"prop1": 777, "prop2": "hoge1"}, {"prop1": 888, "prop2": "hoge2"}]),
    ],
    T.StructType([
        T.StructField("RecordID", T.IntegerType(), False),
        T.StructField(
            "Info",
            T.ArrayType(
                T.StructType([
                    T.StructField("prop1", T.IntegerType(), True),
                    T.StructField("prop2", T.StringType(), True)
                ])
            )
        )
    ])
)
dt_df.display()

# リストオブジェクトを展開
dt_df = dt_df.withColumn("Info", F.explode("Info"))
dt_df.display()

# 行展開をリストオブジェクトへ戻す
dt_df = dt_df.groupBy("RecordID").agg(F.collect_list("Info"))
dt_df.display()
MethodDescription
explode(col: Column)リストオブジェクトを行方向へ展開
collect_list(col: Column)グループ化された項目をリストオブジェクトへ変換 (重複有り)
collect_set(col: Column)グループ化された項目をリストオブジェクトへ変換 (重複無し)
flatten(col: Column)複数の配列を単一の配列に結合

サブフィールドを参照

  • Object:Field … JSON として Field を参照
  • Object.Field … StructType として StructField を参照

スキーマを扱う

import pyspark.sql.functions as F
import pyspark.sql.types as T

# JSON からスキーマ情報を生成]
# schema = T.StructType([T.StructField("hoge", T.IntegerType())])
schema = F.schema_of_json("{'hoge': 123}")
# JSON 文字列と JSON Schema から StructType を生成
spark \
    .createDataFrame([["{'hoge': 123}"]], "val string") \
    .withColumn("hoge", F.from_json("val", schema)) \
    .display()

ピボット

CREATE OR REPLACE temp view sample_tbl AS
  SELECT * FROM (
    VALUES
      ('grp_a',1),
      ('grp_a',2),
      ('grp_b',3),
      ('grp_c',4),
      ('grp_c',5)
  ) AS (dimention, measure);

SELECT * FROM sample_tbl PIVOT (sum(measure) FOR dimention IN ("grp_a", "grp_b", "grp_c"))
spark.table("sample_tbl") \
    .groupBy("dimention") \
    .pivot("dimention") \
    .sum() \
    .display()

クエリ実行

Spark では集計処理を書いただけでは処理は行われない。遅延評価を実行するメソッドを呼び出すことで処理が実行される。

MethodDescription
collect, first, head, take全/最初の1/最初のn行を返す
count行数を返す
show上位n行を表形式で表示
describe, summary数値および文字列の列に対する基本統計を計算

キャッシュ機能は以下の通り

  • DataFrame.cache
  • DataFrame.unpersist

ユーザ定義関数

SQL で定義する場合は以下の通り。定義・実行には以下の権限を要する。

TypeUSE CATALOGUSE SCHEMACREATE FUNCTIONEXECUTE
createyesyesyesno
executeyesyesnoyes
-- 関数定義
CREATE OR REPLACE FUNCTION hoge_func(val string) RETURNS string
RETURN val || "aaa";

-- 情報取得
DESC FUNCTION extended hoge_func

-- 関数実行
SELECT hoge_func("hoge");

python 上で使いたい場合は以下のように定義する。

import pandas as pd
import pyspark.sql.types as T
import pyspark.sql.functions as F
from typing import Iterator, Tuple

# Scalar to Scalar 型
@F.udf("int")
def sample_s2s(val: str) -> int:
    return int(val)

# Vector to Vector 型 (即時評価型)
@F.pandas_udf("int")
def sample_v2v1(val: pd.Series) -> pd.Series:
    return val.astype(int)

# Vector to Vector 型 (遅延評価型)
@F.pandas_udf("int")
def sample_v2v2(vals: Iterator[pd.Series]) -> Iterator[pd.Series]:
    return [v + 1 for v in vals]

# Matrics to Vector 型 (遅延評価型)
@F.pandas_udf("int")
def sample_m2v(vals: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    return [v1 + v2 for v1,v2 in vals]

# Vector to Scalar 型: 
@F.pandas_udf("int")
def sample_v2s(val: pd.Series) -> int:
    return val.sum()

dat_df = spark.createDataFrame(
    [
        ["111", "222", 111, 222, "aaa"],
        ["444", "555", 444, 555, "aaa"]
    ],
    T.StructType([
        T.StructField("prop1", T.StringType(), False),
        T.StructField("prop2", T.StringType(), False),
        T.StructField("prop3", T.IntegerType(), False),
        T.StructField("prop4", T.IntegerType(), False),
        T.StructField("prop5", T.StringType(), False)
    ])
)

dat_df \
    .withColumn("sample_s2s", sample_s2s("prop1")) \
    .withColumn("sample_v2v1", sample_v2v1("prop2")) \
    .withColumn("sample_v2v2", sample_v2v2("prop3")) \
    .withColumn("sample_m2v", sample_m2v("prop3", "prop4")) \
    .show()
    # .printSchema()

dat_df.select(sample_v2s("prop3")) \
    .show()
    # .printSchema()

相互運用

DataFrame と SQL

# DataFrame を SparkSQL から生成
dt_df = spark.sql("SELECT * FROM samples.tpch.customer")

# DataFrame を SparkSQL から参照できるよう登録する
dt_df.createOrReplaceTempView("hoge")

# SparkSQL で登場する式を使用して各値を取り出せる
dt_df.selectExpr("expr1", "expr2" ...)

# Python UDF を SQL 側で使えるようにする
dat_df.createOrReplaceTempView("dat_df")
spark.udf.register("sample_vcudf", sample_vcudf)
spark.sql("select * from dat_df limit 10").display()

DataFrame と Pandas

dt_df = spark.createDataFrame([[1,2,3]], ["col1", "col2", "col3"])

# Pandas DataFrame 風インターフェイスへ変換
pd_df1 = dt_df.toPandas()

# Pandas DataFrame へ変換
pd_df2 = dt_df.pandas_api()