Notice
Recent Posts
Recent Comments
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Archives
Today
Total
관리 메뉴

SYDev

[Spark The Definitive Guide] Part 6-2. 고급 분석과 머신러닝(ch 29 ~ 31) 본문

KHUDA 5th/Data Engineering

[Spark The Definitive Guide] Part 6-2. 고급 분석과 머신러닝(ch 29 ~ 31)

시데브 2024. 4. 11. 02:08
데이터 엔지니어링 심화트랙 5주차 정리 내용

 

 

Chapter 29. 비지도 학습

29.1. 활용 사례

  • 데이터 이상치 탐지: 데이터셋 내 다수의 값이 하나의 큰 그룹으로 군집화되고, 나머지 값은 몇몇 소그룹으로 군집화되는 경우 해당 소그룹을 추가 조사하여 이상치를 탐지할 수 있다.
  • 토픽 모델링: 많은 양의 텍스트 문서를 사전에 학습하여 서로 다른 텍스트 문서 사이의 공통적인 주체를 도출할 수 있다.

 

29.2. 모델 확장성

 

 

29.3 k-평균

>>> from pyspark.ml.feature import VectorAssembler
>>> va = VectorAssembler()\
... .setInputCols(["Quantity", "UnitPrice"])\
... .setOutputCol("features")
>>> sales = va.transform(spark.read.format("csv")
... .option("header", "true")
... .option("inferSchema", "true")
... .load("/Users/sangyeong_park/CE/KHUDA_5th/Data_Engineering/Spark-The-Definitive-Guide-master/data/retail-data/by-day/*.csv")
... .limit(50)
... .coalesce(1)
... .where("Description IS NOT NULL"))
>>> sales.cache()

29.3.1. 모델 하이퍼파라미터

  • k: 최종 생성하고자 하는 군집 수

29.3.2. 학습 파라미터

  • initMode: 군집 중심의 시작 위치를 결정하는 알고리즘. random과 k-means||가 제공됨
  • initSteps: k-means|| 초기화 모드의 단계 수. 기본값은 2이며, 0보다 커야 한다.
  • maxIter: 군집화를 수행할 총 반복 횟수
  • tol: 중심값의 변화가 모델이 충분히 최적화되었다는 것을 알려주는 임곗값을 지정하여 maxIter 값만큼 반복 수행을 하기 전에 조기 중지시킬 수 있다.

29.3.3. 실습 예제

>>> from pyspark.ml.clustering import KMeans
>>> km = KMeans().setK(5)
>>> print(km.explainParams())
>>> kmModel = km.fit(sales)

 

29.3.4. k-평균 평가지표 요약 정보

  • k-평균의 summary: 생성된 군집에 대한 정보와 상대적 크기 포함
  • computeCost: 군집내 오차제곱합 계산 -> 오차제곱합을 최소화해야 함
>>> summary = kmModel.summary
>>> print(summary.clusterSizes)
>>> centers = kmModel.clusterCenters()
>>> print("Cluster Centers: ")
>>> for center in centers:

 

29.4. 이분법 k-평균

  • 최초에 단일 그룹을 생성한 다음, 해당 그룹을 더 작은 그룹으로 나누고, 마지막에는 사용자가 지정한 수의 군집으로 끝남
  • 일반적으로 k-평균보다 더 빠르며 군집 결과도 차이가 있다.

29.4.2. 학습 파라미터

  • minDivisibleClusterSize: 군집으로 분류하기 위해 포함되어야 할 최소 데이터 수 혹은 최소 데이터 비율
  • maxIter

29.4.3. 실습 예제

>>> from pyspark.ml.clustering import BisectingKMeans
>>> bkm = BisectingKMeans().setK(5).setMaxIter(5)
>>> bkmModel = bkm.fit(sales)

 

29.4.4. 이분법 k-평균 요약 정보

>>> summary = bkmModel.summary
>>> print(summary.clusterSizes)
>>> kmModel.computeCost(sales)
>>> bkmModel.computeCost(sales)
>>> centers = bkmModel.clusterCenters()
>>> for center in centers
... 	print(center)

 

 

29.5. 가우시안 혼합 모델

  • 각 군집이 가우시안 분포로부터 무작위 추출을 하여 데이터를 생성한다고 가정
  • 생성된 군집 가장자리에 데이터가 포함될 확률이 낮고, 군집 중앙에 데이터 포함될 확률이 높아야 함
  • k-평균을 좀 더 유연하게 변형한 알고리즘

29.5.1,-2. 파라미터

  • k
  • maxIter
  • tol

29.5.3. 실습 예제

>>> from pyspark.ml.clustering import GaussianMixture
>>> gmm = GaussianMixture().setK(5)
>>> print(gmm.explainParams())
>>> model = gmm.fit(sales)

 

>>> summary = model.summary
>>> print(model.weights)
[0.4661198587898743, 0.07999987251969042, 0.2538801505161057, 0.060027692244960205, 0.13997242592936932]
>>> model.gaussiansDF.show()
+--------------------+--------------------+
|                mean|                 cov|
+--------------------+--------------------+
|[8.24510507364062...|11.82676926004507...|
|[19.9999941018815...|48.00004718500557...|
|[3.05968380638117...|1.114399808533063...|
|[43.9907106264962...|32.17245079845634...|
|[22.8569445957599...|3.265648219547842...|
+--------------------+--------------------+

>>> summary.cluster.show()
+----------+
|prediction|
+----------+
|         3|
|         4|
|         4|
|         4|
|         0|
|         3|
|         1|
|         4|
|         0|
|         0|
|         2|
|         0|
|         0|
|         0|
|         0|
|         0|
|         3|
|         4|
|         0|
|         0|
+----------+
only showing top 20 rows

>>> summary.clusterSizes
[25, 4, 11, 3, 7]
>>> summary.probability.show()
+--------------------+
|         probability|
+--------------------+
|[1.37588698714851...|
|[6.41394482309844...|
|[1.27632212498067...|
|[7.88260292980270...|
|[0.98409715736103...|
|[1.37404323485283...|
|[1.18148507737198...|
|[1.43628966658207...|
|[0.67519709563195...|
|[0.67519709563195...|
|[0.37835078668946...|
|[0.58309212870069...|
|[0.99999848058199...|
|[0.99999848058199...|
|[0.99999999925803...|
|[0.85778970699783...|
|[1.36458180149443...|
|[7.88260292980270...|
|[0.67519709563195...|
|[0.67519709563195...|
+--------------------+
only showing top 20 rows

 

29.6. 잠재 디리클레 할당

  • 잠재 디리클레 할당(Latent Dirichlet Allocation, LDA): 일반적으로 텍스트 문서에 대한 토픽 모델링을 수행하는 데 사용되는 계층적 군집화 모델. 
  • 주제와 관련된 일련의 문서와 키워드로부터 주제를 추출
  • 스파크에서 LDA를 구현하는 두 가지 방법 -> Online LDA, 기댓값 최적화(expectation maximization)
  • 온라인 LDA -> 샘플 데이터가 많은 경우 적합
  • 기댓값 최적화 -> 어휘 수가 많은 경우 적합
  • 텍스트 데이터를 LDA에 입력하려면 수치형으로 먼저 변환해야 함 -> CountVectorizer 사용

29.6.1. 하이퍼파라미터

  • k: 추론할 총 주제 수
  • docConcentration: 디리클레 분포 파라미터 -> LDA에서 문헌별 주제 분포를 결정(얼마나 밀집한지 혹은 희소한지)하는 파라미터
  • topicConcentraion: 주제가 가지는 단어 분포의 사전 추정치. 대칭 디리클레 분포 파라미터

29.6.2. 학습 파라미터

  • maxIter
  • optimizer: 기댓값 최적화 or Online LDA
  • learningDecay: 지수적 감쇠율
  • learningOffset: 초기 반복 수행 횟수를 줄이는 학습 파라미터. 값이 클수록 초기 반복 횟수가 감소
  • optimizeDocConcentration: 학습 과정에서 디리클레 파라미터인 docConcentration이 최적화될지 여부
  • subsamplingRate: 미니배치 하강법의 각 반복 수행에서 샘플링 및 적용되는 말뭉치의 비율
  • seed: LDA 모델의 재현성을 위해 임의의 시드 지정
  • checkpointInterval: 체크포인트 기능

29.6.3. 예측 파라미터

  • topicDistributionCol: 각 문서의 주제 혼합 분포의 결과를 출력하는 컬럼

29.6.4. 실습 예제

>>> from pyspark.ml.feature import Tokenizer, CountVectorizer
>>> tkn = Tokenizer().setInputCol("Description").setOutputCol("DescOut")
>>> tokenized = tkn.transform(sales.drop("features"))
>>> cv = CountVectorizer()\
... .setInputCol("DescOut")\
... .setOutputCol("features")\
... .setVocabSize(500)\
... .setMinTF(0)\
... .setMinDF(0)\
... .setBinary(True)
>>> cvFitted = cv.fit(tokenized)
>>> prepped = cvFitted.transform(tokenized)
>>> from pyspark.ml.clustering import LDA
>>> lda = LDA().setK(10).setMaxIter(5)
>>> print(lda.explainParams())
>>> model = lda.fit(prepped)
>>> model.describeTopics(3).show()
>>> cvFitted.vocabulary

 

  • model.logLikelihood, model.logPerplexity를 사용하여 로그 우도와 복잡도(기술적 평가지표)를 계산할 수 있다.

 

Chapter 30. 그래프 분석

  • 비방향성 그래프(undirected graph)
  • 방향성 그래프(directed graph)
  • 신용카드 사기 적발
  • 모티프 발견
  • 서지네트워크에서 특정 논문의 중요도 결정
  • 페이지랭크 알고리즘을 활용한 웹페이지 순위 결정 
>>> bikeStations = spark.read.option("header", "true")\                         
... .csv("/Users/sangyeong_park/CE/KHUDA_5th/Data_Engineering/Spark-The-Definitive-Guide-master/data/bike-data/201508_station_data.csv")
>>> tripData = spark.read.option("header", "true")\
... .csv("/Users/sangyeong_park/CE/KHUDA_5th/Data_Engineering/Spark-The-Definitive-Guide-master/data/bike-data/201508_trip_data.csv")

-> 자전거 여행 데이터

 

30.1 그래프 작성하기

  • 정점과 에지를 정의해야 하는데, 이들은 각각 별도 명명된 컬럼으로 표현되는 Dataframe
  • GraphFrame 라이브러리에서 제시하는 컬럼에 대한 명명규칙을 사용해야 한다.
  • 정점 테이블에서는 index를 id로 정의(문자열 타입), 에지 테이블에서는 각 에지의 시작 정점 ID를 src로, 도착 정점 ID를 dst로 표시
>>> stationVertices = bikeStations.withColumnRenamed("name", "id").distinct()
>>> tripEdges = tripData\
... .withColumnRenamed("Start Station", "src")\
... .withColumnRenamed("End Station", "dst")

 

>>> from graphframes import GraphFrame
>>> stationGraph = GraphFrame(stationVertices, tripEdges)
>>> stationGraph.cache()

 

-> import는 됐는데, 막상 GraphFrame 객체 생성이 안 돼서 찾아보니, 스파크 세션을 다음 명령어로 다시 시작해줘야 한다.

bin/pyspark --packages graphframes:graphframes:0.8.3-spark3.5-s_2.12

-> 내 경우에는 spark 버전이 3.5.1, scala 버전이 2.12.18이라서 위와 같이 설정함

https://stackoverflow.com/questions/68708862/py4jjavaerror-an-error-occurred-while-calling-o65-creategraph

 

-> 실행할 때, spark와 scala 버전 확인 가능!

 

>>> print("Total Number of Stations: " + str(stationGraph.vertices.count()))
>>> print("Total Number of Trips in Graph: " + str(stationGraph.edges.count()))                                      
>>> print("Total Number of Trips in Original Data: " + str(tripData.count()))

 

 

30.2. 그래프 쿼리하기

>>> from pyspark.sql.functions import desc
>>> stationGraph.edges.groupBy("src", "dst").count().orderBy(desc("count")).show(10)

 

>>> stationGraph.edges\
... .where("src = 'Townsend at 7th' OR dst = 'Townsend at 7th'")\               
... .groupBy("src", "dst").count()\
... .orderBy(desc("count"))\
... .show(10)

-> 특정 도착지를 기준으로 해당 지점에서의 출발과 도착 횟수를 계산

 

30.2.1. 서브그래프

  • 규모가 큰 그래프 안에서 형성되는 작은 규모의 그래프
  • 쿼리 기능을 사용하여 서브 그래프 생성
>>> townAnd7thEdges = stationGraph.edges\
...   .where("src = 'Townsend at 7th' OR dst = 'Townsend at 7th'")
>>> subgraph = GraphFrame(stationGraph.vertices, townAnd7thEdges)

 

30.3. 모티프 찾기

  • 구조적 패턴을 그래프로 표현하는 방법
  • 모티프를 지정하면 실제 데이터 대신 데이터의 패턴을 쿼리
>>> motifs = stationGraph.find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[ca]->(a)")

-> ab는 a에서 b로 가는 에지를 의미

-> 정점 a, b, c와 각 에제의 중첩 필드가 포함된 DataFrame이 생성됨

예제에서 구현한 삼각형 모티프

 

>>> from pyspark.sql.functions import expr
>>> motifs.selectExpr("*",
...     "to_timestamp(ab.`Start Date`, 'MM-dd-yyyy HH:mm') as abStart",
...     "to_timestamp(bc.`Start Date`, 'MM-dd-yyyy HH:mm') as bcStart",
...     "to_timestamp(ca.`Start Date`, 'MM-dd-yyyy HH:mm') as caStart")\
...   .where("ca.`Bike #` = bc.`Bike #`").where("ab.`Bike #` = bc.`Bike #`")\
...   .where("a.id != b.id").where("b.id != c.id")\
...   .where("abStart < bcStart").where("bcStart < caStart")\
...   .orderBy(expr("cast(caStart as long) - cast(abStart as long)"))\
...   .selectExpr("a.id", "b.id", "c.id", "ab.`Start Date`", "ca.`End Date`")\
...   .limit(1).show(1, False)

-> spark 3.0 버전부터는 date format이 MM/dd/yyyy에서 MM-dd-yyyy로 바뀜

-> 타임 스탬프를 이용하여, 특정 자전거를 대상으로 지점 a에서 b, c, 다시 a로 이동했던 가장 짧은 경로는 무엇인지 탐색

 

-> 왜인진 모르겠지만 결과가 안 나옴. (쿼리가 잘못됐나?)

 

30.4. 그래프 알고리즘

30.4.1. 페이지랭크

  • 웹사이트의 중요성을 대략 판단하기 위해 특정 웹 페이지가 다른 웹 페이지로부터 받는 링크 수와 품질을 계산
>>> from pyspark.sql.functions import desc
>>> ranks = stationGraph.pageRank(resetProbability=0.15, maxIter=10)
>>> ranks.vertices.orderBy(desc("pagerank")).select("id", "pagerank").show(10)

-> 페이지랭크를 이용하여 중요한 자전거 도착 지점 판별

-> 페이지랭크가 높을 수록 해당 지점으로 많은 사람들이 왕래한다는 것을 의미

 

30.4.2. In-Degree와 Out-Degree 지표

  • 각 지점의 출입을 측정하기 위해 사용하는 지표

>>> inDeg = stationGraph.inDegrees
>>> inDeg.orderBy(desc("inDegree")).show(5, False)
>>> outDeg = stationGraph.outDegrees
>>> outDeg.orderBy(desc("outDegree")).show(5, False)

 

 

>>> degreeRatio = inDeg.join(outDeg, "id")\
...   .selectExpr("id", "double(inDegree)/double(outDegree) as degreeRatio")
>>> degreeRatio.orderBy(desc("degreeRatio")).show(10, False)
>>> degreeRatio.orderBy("degreeRatio").show(10, False)

-> 비율이 높은 곳은 주로 여행이 끝나는 지점, 비율이 낮은 곳은 주로 여행이 시작되는 지점

 

30.4.3. 너비 우선 탐색

  • BFS를 이용하여 두 개의 노드를 연결하는 방법을 탐색
>>> stationGraph.bfs(fromExpr="id = 'Townsend at 7th'",
...   toExpr="id = 'Spear at Folsom'", maxPathLength=2).show(10)

 

30.4.4. 연결 요소

  • connected component: 자체적인 연결을 가지고 있지만 큰 그래프에는 연결되지 않은(방향성이 없는) 서브그래프

 

  • 로컬 시스템에서 이 알고리즘을 실행하기 위해, 데이터를 샘플링
  • 샘플을 사용하면 가비지 컬렉션 이슈와 같은 스파크 애플리케이션 충돌을 발생시키지 않고 결과를 얻을 수 있다.
>>> spark.sparkContext.setCheckpointDir("/Users/sangyeong_park/CE/KHUDA_5th/Data_Engineering/spark-3.5.1-bin-hadoop3/tmp/checkpoints")
>>> minGraph = GraphFrame(stationVertices, tripEdges.sample(False, 0.1))
>>> cc = minGraph.connectedComponents()
>>>
>>> cc.where("component != 0").show()

 

-> 같은 component값을 가진 지점들은 서로 연결되어있음, 다른 component를 가졌다는 것은 서로 분리되었음을 의미

 

30.4.5. 강한 연결 요소

  • strongly connected component: 방향성이 고려된 상태로 강하게 연결된 구성요소(내부의 모든 정점 쌍 사이에 경로가 존재하는 서브그래프)
>>> scc = minGraph.stronglyConnectedComponents(maxIter=3)
>>> scc.groupBy("component").count().show()

 

Chapter 31. 딥러닝

  • 아파치 스파크는 빅데이터와 병렬 컴퓨팅 시스템으로서의 강점을 가지고 있기 때문에 딥러닝을 사용하기에 적합한 프레임워크이다.

31.2. 스파크에서 딥러닝을 사용하는 방법

  • 스파크에서 딥러닝을 사용하는 세 가지 주요 방법은 추론특이 생성과 전이 학습모델 학습

추론

  • 딥러닝을 사용하는 가장 간단한 방법은 이미 학습된 모델을 스파크로 가져와서 대용량 데이터셋에 병렬로 적용하는 것
  • 파이스파크를 사용하면 맵 함수를 사용해서 텐서플로나 파이토치와 같은 프레임워크를 호출하여 분산 처리를 통한 추론을 할 수 있다.

특이 생성과 전이 학습

  • 특징 생성: 기존 모델을 결과 도출이 아닌 특징을 생성하는 데 사용하는 
  • 전이 학습: 대부분의 딥러닝 모델은 하위 계층에서 철저한 학습을 통해 유용한 특징 표현 학습(feature representation learning)을 하는데, 이런 특징들을 원래의 학습 데이터가 다루지 않았던 새로운 문제를 해결할 수 있는 새로운 모델을 학습하는 데 활용
  • 본래 분석하고자 하는 학습 데이터가 충분하지 않은 경우 유용

모델 학습

  • 스파크에서 신규 딥러닝 모델 개발하는 방법은 두 가지
  • 스파크 클러스터를 사용하여 단일 모델에 대한 학습을 여러 서버에서 병렬 처리하고, 각 서버 간의 통신을 통해 최종 결과를 업데이트
  • 특징 라이브러리를 사용하여 다양한 모델 객체를 병렬로 학습시키고 다양한 모델 아키텍처와 하이퍼파라미터를 검토하여 최종 모델 선택과 최적화 과정을 효율적으로 수행
  • 모델 학습 시 병렬 처리를 원치 않는 경우, 이런 라이브러리들을 사용하여 클러스터로부터 데이터를 추출하고, 텐서플로 같은 딥러닝 프레임워크에서 지원하는 데이터 포맷을 사용하여 단일 머신 기반의 모델 학습 스크립트로 내보낼 수 있다.

 

31.3. 딥러닝 라이브러리

31.3.1. MLlib에서 지원하는 신경망

  • 스파크의 MLlib는 ml.classification.MultilayerPerceptronClassifier 클래스의 다층 퍼셉트론 분류기와 같은 단일 심층 학습 알고리즘을 지원한다.
  • 이 클래스는 전이 학습을 할 때 분류 모델의 마지막 몇 개 계층을 학습하는 데 가장 유용하다.

31.3.2. 텐서프레임

  • TensorFrames: 스파크 DataFrame과 텐서플로 간에 데이터 송수신을 쉽게 하도록 도와주는 추론 및 전이 학습 지향 라이브러리
  • 텐서플로에서 스파크로 데이터를 전달하기 위해 단순하지만 최적화된 인터페이스 제공하는 데 중점을 둔다.
  • TensorFrame을 사용하여 스파크 DataFrame에 모델을 적용하면 빠른 데이터 전송 및 초기 구동 비용에 대한 비용 상쇄 효과로 인해, 텐서플로 모델을 직접 불러오는 파이썬 맵 함수를 호출하는 것보다 일반적으로 더 효율적이다.

31.3.3. BigDL

  • 인텔에서 개발한 아파치 스파크의 분산 딥러닝 프레임워크. 추론을 사용하여 딥러닝 모델의 빠른 적용과 대용량 모델의 분산 학습(distributed training)을 지원한다.
  • GPU가 아니라 CPU 활용에 최적화되어, 아파치 하둡과 같은 기존의 CPU 기반 클러스터에서도 효율적으로 딥러닝을 구현할 수 있다.

31.3.4. TensorFlowOnSpark

  • 스파크 클러스터에서 텐서플로 모델을 병렬로 학습시키는 데 사용하는 대중적인 라이브러리
  • 스파크 잡 내에서 텐서플로의 분산 모드를 실행시키고, 스파크 RDD 혹은 DataFRame 데이터를 텐서플로 잡에 자동으로 공급한다.

DeepLearning4J

  • 단일 노드 및 분산 학습 옵션을 모두 제공하는 Java 및 Scala의 오픈소스이자 분산 딥러닝 프로젝트이다.
  • JVM용으로 설계되어, 파이썬을 개발 프로세스에 추가하지 않으려는 사용자 그룹에 더 높은 편의성을 제공

31.3.6. 딥러닝 파이프라인

  • 딥러닝 기능을 스파크의 ML 라이브러리인 API에 통합시킨 데이터브릭스의 오픈소스 패키지
  • 딥러닝 프레임워크를 표준 스파크 API(ML 파이프라인이나 스파크 SQL 등)에 통합하여 사용하기 쉽게 만들고, 모든 연산을 기본적으로 분산 처리하는 데에 목표를 둔다.

 

31.4. 딥러닝 파이프라인을 사용한 간단한 예제

>>> from sparkdl import readImages
>>> img_dir = '/Users/sangyeong_park/CE/KHUDA_5th/Data_Engineering/Spark-The-Definitive-Guide-master/data/deep-learning-images/'
>>> image_df = readImages(img_dir)
>>> image_df.printSchema()

-> 플라워 데이터셋의 메타데이터

31.4.3. 전이 학습

>>> from sparkdl import readImages
>>> from pyspark.sql.functions import lit
>>> tulips_df = readImages(img_dir + "/tulips").withColumn("label", lit(1))
>>> daisy_df = readImages(img_dir + "/daisy").withColumn("label", lit(0))
>>> tulips_train, tulips_test = tulips_df.randomSplit([0.6, 0.4])
>>> daisy_train, daisy_test = daisy_df.randomSplit([0.6, 0.4])
>>> train_df = tulips_train.unionAll(daisy_train)
>>> test_df = tulips_test.unionAll(daisy_test)

-> 유형별 꽃에 대한 데이터를 로드하고, 학습셋과 테스트셋 생성

 

from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from sparkdl import DeepImageFeaturizer
featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features",
  modelName="InceptionV3")
lr = LogisticRegression(maxIter=1, regParam=0.05, elasticNetParam=0.3,
  labelCol="label")
p = Pipeline(stages=[featurizer, lr])
p_model = p.fit(train_df)

 

-> 이미지 패턴을 식별하는 데 사용되는 사전학습된 모델 Inception을 활용

-> 다양하고 보편적인 사물 및 동물의 이미지를 인식하는 데 적합하도록 사전학습됨

-> 전이 학습을 통해 다양한 꽃을 구별하도록 수정

(tf 버전이 맞지 않아 여기서부터 실습 못함)

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
tested_df = p_model.transform(test_df)
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
print("Test set accuracy = " + str(evaluator.evaluate(tested_df.select(
  "prediction", "label"))))

-> 모델 학습을 완료하고, 분류 평가기를 사용

 

from pyspark.sql.types import DoubleType
from pyspark.sql.functions import expr
# a simple UDF to convert the value to a double
def _p1(v):
  return float(v.array[1])
p1 = udf(_p1, DoubleType())
df = tested_df.withColumn("p_1", p1(tested_df.probability))
wrong_df = df.orderBy(expr("abs(p_1 - label)"), ascending=False)
wrong_df.select("filePath", "p_1", "label").limit(10).show()

-> 예제의 Dataframe을 사용하여 예측이 잘못된 로우와 이미지를 검사

 


참고자료

  • "스파크 완벽 가이드" , 빌 체임버스 , 마테이 자하리아 저자(글) · 우성한 , 이영호 , 강재원 번역