"Big Data + AI: A Look Ahead at Best Practices in the Health Industry" ---- Single-Entity Recognition with DBSCAN and Soft Clustering

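This article explores how to use the DBSCAN algorithm for single-entity recognition, combined with a soft-clustering method that further splits records which look like one entity into the distinct individuals they actually belong to. It covers the age-standardization step and the Python implementation of DBSCAN and soft clustering, including the key functions and their integration with Spark SQL. Finally, it walks through the complete entity-unification flow, from data preprocessing to the final result.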



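Single-entity recognition with DBSCAN and soft clustering can be used to identify the same individual among records of many different individuals.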

Open-source libraries used

```python
import os
import json
import math
import numbers
import numpy as np

import itertools as it
import operator as op

import pandas as pd
pd.set_option('display.max_columns', None)
import cufflinks as cf
cf.go_offline()

from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.types import *
# Note: this wildcard import shadows some Python builtins (e.g. sum, max, min)
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.storagelevel import StorageLevel
```



Age standardization

```python
# Check whether the value x can be parsed as a finite float
# ('nan' and 'inf' parse as floats but would make int() fail later).
def check_float(x):
    try:
        return math.isfinite(float(x))
    except ValueError:
        return False


# Standardize the age to years based on PI_AGEUNIT.
def standardize_age(age, ageunit):
    age = str(age)
    age = float(age) if check_float(age) else 0.0

    ageunit = str(ageunit)
    if ageunit == '月':    # months
        age = age / 12.0
    elif ageunit == '周':  # weeks
        age = age / 52.0
    elif ageunit == '日':  # days
        age = age / 365.0
    # Any other unit is treated as years.

    # Ages between 0 and 1 year are rounded up to 1; 0 stays 0 and marks an invalid age.
    if 0 < age < 1:
        return 1
    return int(age)
standardize_age = udf(standardize_age, IntegerType())
```


DBSCAN

A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise
Martin Ester, Hans-Peter Kriegel, Jörg Sander, Xiaowei Xu
DBSCAN: Density-Based Spatial Clustering of Applications with Noise

Density-based spatial clustering of applications with noise (DBSCAN) is a data clustering algorithm proposed by Martin Ester, Hans-Peter Kriegel, Jörg Sander and Xiaowei Xu in 1996.[1] It is a density-based clustering algorithm: given a set of points in some space, it groups together points that are closely packed together (points with many nearby neighbors), marking as outliers points that lie alone in low-density regions (whose nearest neighbors are too far away). DBSCAN is one of the most common clustering algorithms and also one of the most cited in the scientific literature.
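Written out in the notation of the pseudocode further down (a database $D$ and a distance function $d$), the two definitions that the steps below rely on are, as a short summary:

$$N_\varepsilon(p) = \{\, q \in D \mid d(p, q) \le \varepsilon \,\}$$

$$p \text{ is a core point} \iff |N_\varepsilon(p)| \ge \mathit{minPts}$$

A non-core point that lies in $N_\varepsilon(p)$ of some core point $p$ is a border point and joins that cluster; every other non-core point is noise. For the one-dimensional soft clustering used later (ages, birth years), $d(p, q) = |p - q|$. Note that the Python helper `_eps_neighborhood` in this article tests the strict inequality $d(p, q) < \varepsilon$.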

The DBSCAN algorithm can be abstracted into the following steps:[5]
1. Find the ε (eps) neighbors of every point, and identify the core points, i.e. the points with at least minPts points (counting themselves) in their ε-neighborhood.
2. Find the connected components of core points on the neighbor graph, ignoring all non-core points.
3. Assign each non-core point to a nearby cluster if it is an ε (eps) neighbor of a core point in that cluster; otherwise assign it to noise.
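The three steps above can be sketched directly in Python. This is a minimal, stdlib-only illustration for one-dimensional data (distance |a - b|, neighborhoods tested with `<=` as in the pseudocode); the name `dbscan_1d` is hypothetical and this is not the implementation used later in the article. Noise points get the label `None`. A border point within eps of core points from two clusters takes the cluster of the first such core point found, which reflects DBSCAN's known order dependence for border points.

```python
from collections import deque


def dbscan_1d(points, eps, min_pts):
    """Label each point with a cluster id (1, 2, ...) or None for noise."""
    n = len(points)

    # Step 1: eps-neighbors of every point (a point is its own neighbor) and core points
    neighbors = [[j for j in range(n) if abs(points[i] - points[j]) <= eps]
                 for i in range(n)]
    is_core = [len(nb) >= min_pts for nb in neighbors]

    # Step 2: connected components of the core points on the neighbor graph
    labels = [None] * n
    cluster_id = 0
    for start in range(n):
        if not is_core[start] or labels[start] is not None:
            continue
        cluster_id += 1
        labels[start] = cluster_id
        queue = deque([start])
        while queue:
            p = queue.popleft()
            for q in neighbors[p]:
                if is_core[q] and labels[q] is None:
                    labels[q] = cluster_id
                    queue.append(q)

    # Step 3: a non-core point joins the cluster of the first core point within eps
    for i in range(n):
        if not is_core[i]:
            for q in neighbors[i]:
                if is_core[q]:
                    labels[i] = labels[q]
                    break
    return labels


print(dbscan_1d([1, 1.2, 0.8, 3.7, 3.9, 3.6, 10], 0.5, 2))
# [1, 1, 1, 2, 2, 2, None]
```

The sample data is the same as in `test_dbscan` further down, so both approaches can be compared on it.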

Adapted from:
https://github.com/choffstein/dbscan/blob/master/dbscan/dbscan.py



```text
DBSCAN(DB, distFunc, eps, minPts) {
   C = 0                                                  /* Cluster counter */
   for each point P in database DB {
      if label(P) ≠ undefined then continue               /* Previously processed in inner loop */
      Neighbors N = RangeQuery(DB, distFunc, P, eps)      /* Find neighbors */
      if |N| < minPts then {                              /* Density check */
         label(P) = Noise                                 /* Label as Noise */
         continue
      }
      C = C + 1                                           /* next cluster label */
      label(P) = C                                        /* Label initial point */
      Seed set S = N \ {P}                                /* Neighbors to expand */
      for each point Q in S {                             /* Process every seed point */
         if label(Q) = Noise then label(Q) = C            /* Change Noise to border point */
         if label(Q) ≠ undefined then continue            /* Previously processed */
         label(Q) = C                                     /* Label neighbor */
         Neighbors N = RangeQuery(DB, distFunc, Q, eps)   /* Find neighbors */
         if |N| ≥ minPts then {                           /* Density check */
            S = S ∪ N                                     /* Add new neighbors to seed set */
         }
      }
   }
}

RangeQuery(DB, distFunc, Q, eps) {
   Neighbors = empty list
   for each point P in database DB {                      /* Scan all points in the database */
      if distFunc(Q, P) ≤ eps then {                      /* Compute distance and check epsilon */
         Neighbors = Neighbors ∪ {P}                      /* Add to result */
      }
   }
   return Neighbors
}
```


```python
UNCLASSIFIED = False  # label of points that have not been visited yet
NOISE = None          # label of noise points

# Euclidean distance between two column vectors
def _dist(p, q):
    return math.sqrt(np.power(p - q, 2).sum())

# True if q lies strictly within eps of p (the pseudocode above uses <=)
def _eps_neighborhood(p, q, eps):
    return _dist(p, q) < eps

# Indices of all points (columns of m) within eps of point_id, including point_id itself
def _region_query(m, point_id, eps):
    n_points = m.shape[1]
    seeds = []
    for i in range(0, n_points):
        if _eps_neighborhood(m[:, point_id], m[:, i], eps):
            seeds.append(i)
    return seeds

# Grow cluster cluster_id from point_id. Returns False (and marks point_id as NOISE)
# if point_id is not a core point.
def _expand_cluster(m, classifications, point_id, cluster_id, eps, min_points):
    seeds = _region_query(m, point_id, eps)
    if len(seeds) < min_points:
        classifications[point_id] = NOISE
        return False
    classifications[point_id] = cluster_id
    for seed_id in seeds:
        classifications[seed_id] = cluster_id

    while len(seeds) > 0:
        current_point = seeds[0]
        results = _region_query(m, current_point, eps)
        # Only core points expand the cluster
        if len(results) >= min_points:
            for result_point in results:
                if classifications[result_point] == UNCLASSIFIED or \
                   classifications[result_point] == NOISE:
                    # Unvisited points become new seeds; noise points become border points
                    if classifications[result_point] == UNCLASSIFIED:
                        seeds.append(result_point)
                    classifications[result_point] = cluster_id
        seeds = seeds[1:]
    return True


# DBSCAN for data stored column-wise, e.g. a 1 x n matrix for a single-dimensional vector
def dbscan(m, eps, min_points):
    """Implementation of Density-Based Spatial Clustering of Applications with Noise.
    See https://en.wikipedia.org/wiki/DBSCAN
    (scikit-learn probably has a better implementation.)

    Uses Euclidean distance as the measure.

    Inputs:
    m - A matrix whose columns are feature vectors
    eps - Points closer than eps to each other are regionally related (neighbors)
    min_points - The minimum number of points to make a cluster

    Outputs:
    A list with either a cluster id number or NOISE (None) for each
    column vector in m.
    """
    m = np.array(m)
    cluster_id = 1
    n_points = m.shape[1]
    classifications = [UNCLASSIFIED] * n_points
    for point_id in range(0, n_points):
        if classifications[point_id] == UNCLASSIFIED:
            if _expand_cluster(m, classifications, point_id, cluster_id, eps, min_points):
                cluster_id = cluster_id + 1
    return classifications


# DBSCAN for multi-dimensional data: m holds one feature vector per row,
# so it is transposed into the column layout that dbscan() expects
def dbscan_mult(m, eps, min_points):
    return dbscan(np.transpose(np.array(m)), eps, min_points)


def test_dbscan():
    m = np.array([[1, 1.2, 0.8, 3.7, 3.9, 3.6, 10]])
    eps = 0.5
    min_points = 2
    print(m)
    assert dbscan(m, eps, min_points) == [1, 1, 1, 2, 2, 2, None]
```



SOFT-CLUSTERING

This code takes a list of columns and their values. Initially each column holds a single cluster, for example a record-ID cluster and a PI_AGE cluster:

Input : [ [[record1, record2, record3, record4]], [[32, 33, 57, 31]] ]

Output : [ [[record1, record2, record4], [record3]], [[32, 33, 31], [57]] ]

(The function itself returns only the record-ID part, i.e. the first element of the output.)

The clusters formed by Hard Clustering are further broken down by Soft Clustering, which is essentially DBSCAN run separately on each individual Hard cluster. DBSCAN is run on the features specified as Soft Constraints, using the Soft Parameters. Soft clustering can be performed in two ways.

Iterative Soft-Clustering:

For each specified Soft Constraint (feature), we run the clustering iteratively, breaking the cluster into smaller pieces with each iteration. This approach applies when two or more of the constraints differ greatly in their distance metrics, so that clustering the whole feature vector does not make sense.
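As a sketch of the iterative variant, assuming min_points = 1 (the value used in the default soft_params (2.0, 1)), the following splits one hard cluster with one 1-D DBSCAN pass per feature. With min_points = 1 every point is a core point, so in one dimension a DBSCAN cluster is exactly a chain of sorted values whose consecutive gaps are below eps. The names `dbscan_1d_minpts1` and `iterative_soft_cluster` are hypothetical, and the second feature in the usage example is invented for illustration.

```python
def dbscan_1d_minpts1(values, eps):
    """1-D DBSCAN with min_points = 1: clusters are chains of sorted values whose
    consecutive gaps are < eps (strict, like _eps_neighborhood). Labels 1, 2, ...
    follow the order in which each cluster's first value appears in the input."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    chain = [0] * len(values)
    current = 0
    for pos, i in enumerate(order):
        if pos > 0 and values[i] - values[order[pos - 1]] >= eps:
            current += 1          # gap too large: start a new chain
        chain[i] = current
    first_seen = {}
    return [first_seen.setdefault(c, len(first_seen) + 1) for c in chain]


def iterative_soft_cluster(records, features, eps_per_feature):
    """Split one hard cluster with one DBSCAN pass per feature.

    records: record IDs of the hard cluster
    features: one value list per soft feature, parallel to records
    eps_per_feature: eps for each feature (min_points fixed at 1 here)
    """
    clusters = [list(range(len(records)))]        # start with a single cluster
    for values, eps in zip(features, eps_per_feature):
        next_clusters = []
        for members in clusters:                  # split every current cluster
            labels = dbscan_1d_minpts1([values[m] for m in members], eps)
            groups = {}
            for lab, m in zip(labels, members):
                groups.setdefault(lab, []).append(m)
            next_clusters.extend(groups.values())
        clusters = next_clusters
    return [[records[m] for m in members] for members in clusters]


recs = ['record1', 'record2', 'record3', 'record4']
print(iterative_soft_cluster(recs, [[32, 33, 57, 31]], [2.0]))
# [['record1', 'record2', 'record4'], ['record3']]
print(iterative_soft_cluster(recs, [[32, 33, 57, 31], [170, 182, 165, 171]], [2.0, 5.0]))
# [['record1', 'record4'], ['record2'], ['record3']]
```

The first call reproduces the Input/Output example; the second shows how a further feature only ever splits the clusters left by the previous pass.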

Combined Feature Soft-Clustering:

All the specified Soft Constraints are combined into vectors, and DBSCAN is performed on these vectors using their vector distances. This is possible when all the specified constraints share a homogeneous distance metric.
The flow charts for both methods can be seen in the Flow Diagram section. The ER_soft_cluster function below implements the iterative variant: it runs one DBSCAN pass per soft column.
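For the combined-feature variant, DBSCAN runs once on whole feature vectors with Euclidean distance. The sketch below follows the seed-expansion structure of the pseudocode, uses a strict `< eps` test like `_eps_neighborhood`, and labels noise as -1. The function names and the 2-D sample vectors (two hypothetical features on the same scale, in years) are illustrative assumptions, not part of the original module.

```python
import math


def euclidean(p, q):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def region_query(vectors, i, eps):
    # Indices of all vectors closer than eps to vectors[i], including i itself
    return [j for j, v in enumerate(vectors) if euclidean(vectors[i], v) < eps]


def dbscan_vectors(vectors, eps, min_points):
    """DBSCAN on whole feature vectors; returns cluster ids 1, 2, ... or -1 for noise."""
    NOISE = -1
    labels = [None] * len(vectors)            # None = not visited yet
    cluster_id = 0
    for i in range(len(vectors)):
        if labels[i] is not None:
            continue
        neighbors = region_query(vectors, i, eps)
        if len(neighbors) < min_points:       # not a core point
            labels[i] = NOISE
            continue
        cluster_id += 1
        labels[i] = cluster_id
        seeds = [j for j in neighbors if j != i]
        while seeds:
            q = seeds.pop()
            if labels[q] == NOISE:            # noise reachable from a core point: border point
                labels[q] = cluster_id
            if labels[q] is not None:         # already processed
                continue
            labels[q] = cluster_id
            q_neighbors = region_query(vectors, q, eps)
            if len(q_neighbors) >= min_points:  # q is a core point: keep expanding
                seeds.extend(q_neighbors)
    return labels


vecs = [(32.0, 30.5), (33.0, 31.0), (57.0, 55.0), (31.0, 29.5)]
print(dbscan_vectors(vecs, 2.0, 1))  # [1, 1, 2, 1]
print(dbscan_vectors(vecs, 2.0, 2))  # [1, 1, -1, 1]
```

The second and fourth vectors are 2.5 apart, more than eps, yet end up in the same cluster because both are within eps of the first vector; this density-reachability is what distinguishes DBSCAN from simple threshold matching.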


```python
def ER_soft_cluster(arr_list, params=[(2.0, 1), (2.0, 1)]):
    # arr_list[0]  : record-ID clusters, e.g. [[rec1, rec2, rec3]]
    # arr_list[1:] : one entry per soft column, holding that column's values cluster by cluster
    # params[c-1]  : (eps, min_points) for soft column c
    # A single record cannot be broken down further, so return the input unchanged.
    if len(arr_list[0][0]) <= 1:
        return arr_list[0]

    # The number of clusters grows as every soft column breaks them into smaller parts
    curr_clusters = len(arr_list[0])
    # Run DBSCAN iteratively, once for every column in soft_cols
    for c in range(1, len(arr_list)):
        eps, min_points = params[c - 1]
        # Temporary lists that hold the broken-down clusters of every column
        t_arr_list = [[] for _ in range(len(arr_list))]
        # For every sub-cluster produced by the previous column's DBSCAN
        for j in range(curr_clusters):
            # DBSCAN labels for this sub-cluster (passed as a 1 x n matrix of floats)
            val_arr = [[float(v) for v in arr_list[c][j]]]
            labels = dbscan(val_arr, eps, min_points)
            # Noise points (label None) become singleton clusters; this also avoids
            # comparing None with int while sorting, which fails in Python 3
            labels = [lab if lab is not None else -(n + 1) for n, lab in enumerate(labels)]
            # Break every column's sub-cluster j apart according to the DBSCAN labels
            for k in range(len(arr_list)):
                pairs = sorted(zip(labels, arr_list[k][j]), key=op.itemgetter(0))
                t_arr_list[k] += [[x[1] for x in grp]
                                  for _, grp in it.groupby(pairs, key=op.itemgetter(0))]
        # Continue with the broken-down lists and the new number of clusters
        arr_list = t_arr_list
        curr_clusters = len(arr_list[0])
    # Return the record-ID clusters
    return arr_list[0]

# Build a UDF that takes the DBSCAN parameters as an argument
def ER_soft_cluster_udf(params):
    return udf(lambda l: ER_soft_cluster(l, params), ArrayType(ArrayType(StringType())))

# Combine the different soft_cols into a single list; each column starts as one cluster
def convert_list(*args):
    return [[arg] for arg in args]
convert_list = udf(convert_list, ArrayType(ArrayType(ArrayType(StringType()))))

# Birth year = year of the record's creation date minus the age.
# Assumes date_str starts with the year followed by '/', e.g. 'YYYY/MM/DD hh:mm:ss'.
def get_birth_year(date_str, age):
    date_str = str(date_str)
    age = int(age)
    year = int(date_str.split(' ')[0].split('/')[0])
    return year - age
get_birth_year = udf(get_birth_year, IntegerType())
```
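get_birth_year splits the date string by hand and raises on malformed input, and an exception inside a Python UDF fails the Spark task. A possible stricter variant, assuming the date part of REC_CREATEDDATE follows 'YYYY/MM/DD' (the code above only relies on the year coming first), parses it with `datetime.strptime` and returns None on bad input so that a UDF would yield null instead. The name `birth_year_or_none` and the format string are assumptions; it could be wrapped with `udf(..., IntegerType())` in the same way as above.

```python
from datetime import datetime


def birth_year_or_none(date_str, age, fmt='%Y/%m/%d'):
    """Hypothetical stricter get_birth_year: year of date_str minus age, or None
    if either the date (in the assumed format) or the age cannot be parsed."""
    try:
        date_part = str(date_str).split(' ')[0]   # drop the time of day
        year = datetime.strptime(date_part, fmt).year
        return year - int(age)
    except (TypeError, ValueError):
        return None


print(birth_year_or_none('2018/05/03 10:22:00', 30))  # 1988
print(birth_year_or_none('not a date', 30))           # None
```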

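Entity unification

HARD CLUSTERING: records that match exactly on PI_NAME, PI_FROM and PI_SEX share an entity ID.

| REC_ID | PI_NAME | PI_FROM | PI_SEX | PI_AGE | ENTITY ID |
|---|---|---|---|---|---|
| 1 | abc | xyz | M | 25 | 1 |
| 2 | abc | xyz | M | 24 | 1 |
| 3 | abc | xyz | M | 12 | 1 |
| 4 | lmn | xyz | M | 32 | 2 |

ACTUAL OPERATION: one row per hard cluster, with its record IDs and ages collected into lists.

| REC_LIST | AGE_LIST |
|---|---|
| [1, 2, 3] | [25, 24, 12] |
| [4] | [32] |

SOFT CLUSTERING: DBSCAN on PI_AGE within each hard cluster (ENTITY ID x.y is sub-cluster y of hard cluster x).

| REC_ID | PI_NAME | PI_FROM | PI_SEX | PI_AGE | ENTITY ID |
|---|---|---|---|---|---|
| 1 | abc | xyz | M | 25 | 1.1 |
| 2 | abc | xyz | M | 24 | 1.1 |
| 3 | abc | xyz | M | 12 | 1.2 |
| 4 | lmn | xyz | M | 32 | 2.1 |

ACTUAL OPERATION (after clustering)

| REC_LIST | AGE_LIST |
|---|---|
| [[1, 2], [3]] | [[25, 24], [12]] |
| [[4]] | [[32]] |

EXPLODE: one row per soft cluster.

| REC_LIST | AGE_LIST |
|---|---|
| [1, 2] | [25, 24] |
| [3] | [12] |
| [4] | [32] |

ASSIGN ID: each soft cluster receives a PI_ID.

| REC_LIST | AGE_LIST | PI_ID |
|---|---|---|
| [1, 2] | [25, 24] | 1 |
| [3] | [12] | 2 |
| [4] | [32] | 3 |

EXPLODE: one row per record.

| REC_ID | AGE | PI_ID |
|---|---|---|
| 1 | 25 | 1 |
| 2 | 24 | 1 |
| 3 | 12 | 2 |
| 4 | 32 | 3 |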
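JOIN: match the REC_IDs of the latest DataFrame against the original DataFrame (final result).

| REC_ID | PI_NAME | PI_FROM | PI_SEX | PI_AGE | PI_ID |
|---|---|---|---|---|---|
| 1 | abc | xyz | M | 25 | 1 |
| 2 | abc | xyz | M | 24 | 1 |
| 3 | abc | xyz | M | 12 | 2 |
| 4 | lmn | xyz | M | 32 | 3 |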

GETTING THE PI_ID FOR GIVEN PARAMS: select the rows that match the given params and return the PI_ID of the first row.
EG: given PI_NAME = abc, PI_FROM = xyz, PI_SEX = M, PI_AGE = 25

SELECTED ROWS

| REC_ID | PI_NAME | PI_FROM | PI_SEX | PI_AGE | PI_ID |
|---|---|---|---|---|---|
| 1 | abc | xyz | M | 25 | 1 |

RETURN PI_ID: 1
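The toy example above can be reproduced end to end in plain Python, which makes the data flow easier to check than the Spark version. This is a Spark-free sketch: the names `resolve_entities` and `get_pi_id` are hypothetical, eps = 2.0 with min_points = 1 mirrors the default soft_params (2.0, 1), and sequential IDs 1, 2, 3 stand in for `monotonically_increasing_id`, whose IDs are unique and increasing but not consecutive.

```python
def resolve_entities(records, hard_keys, soft_key, eps):
    """Hard-cluster on exact matches of hard_keys, soft-cluster soft_key inside each
    hard cluster (1-D DBSCAN with min_points = 1, strict '<'), then number the
    resulting clusters 1, 2, 3, ... Returns {REC_ID: PI_ID}."""
    # HARD CLUSTERING: group records by the exact values of the hard columns
    hard = {}
    for rec in records:
        hard.setdefault(tuple(rec[k] for k in hard_keys), []).append(rec)

    pi_ids, next_id = {}, 1
    for group in hard.values():
        # SOFT CLUSTERING: with min_points = 1 every point is a core point, so a
        # 1-D cluster is a chain of sorted values whose consecutive gaps are < eps
        label, labels, prev = 0, {}, None
        for rec in sorted(group, key=lambda r: r[soft_key]):
            if prev is None or rec[soft_key] - prev >= eps:
                label += 1
            labels[rec['REC_ID']] = label
            prev = rec[soft_key]
        # EXPLODE + ASSIGN ID: clusters are numbered in order of their first record
        label_to_id = {}
        for rec in group:
            lab = labels[rec['REC_ID']]
            if lab not in label_to_id:
                label_to_id[lab] = next_id
                next_id += 1
            pi_ids[rec['REC_ID']] = label_to_id[lab]
    return pi_ids


def get_pi_id(records, pi_ids, **params):
    """Return the PI_ID of the first record matching all params, or None."""
    for rec in records:
        if all(rec[k] == v for k, v in params.items()):
            return pi_ids[rec['REC_ID']]
    return None


records = [
    {'REC_ID': 1, 'PI_NAME': 'abc', 'PI_FROM': 'xyz', 'PI_SEX': 'M', 'PI_AGE': 25},
    {'REC_ID': 2, 'PI_NAME': 'abc', 'PI_FROM': 'xyz', 'PI_SEX': 'M', 'PI_AGE': 24},
    {'REC_ID': 3, 'PI_NAME': 'abc', 'PI_FROM': 'xyz', 'PI_SEX': 'M', 'PI_AGE': 12},
    {'REC_ID': 4, 'PI_NAME': 'lmn', 'PI_FROM': 'xyz', 'PI_SEX': 'M', 'PI_AGE': 32},
]
pi_ids = resolve_entities(records, ['PI_NAME', 'PI_FROM', 'PI_SEX'], 'PI_AGE', eps=2.0)
print(pi_ids)  # {1: 1, 2: 1, 3: 2, 4: 3}
print(get_pi_id(records, pi_ids, PI_NAME='abc', PI_FROM='xyz', PI_SEX='M', PI_AGE=25))  # 1
```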


Entity unification implementation


```python
# Check whether value exists in array
def contains(array, value):
    array = list(array)
    value = int(value)
    return value in array
contains = udf(contains, BooleanType())

# Cluster an age array into groups of the form AGE-1, AGE, AGE+1
def cluster_ages(age_list):
    age_list.sort()
    age_group = []
    tmp_gp = []

    # Iterate through the ages, e.g. [2, 3, 4, 6, 7, 8, 9]
    for a in age_list:
        # Start a run with the first element
        if len(tmp_gp) == 0:
            tmp_gp.append(a)
        # Extend the run while consecutive ages differ by 1, e.g. tmp_gp = [2] <-- 3
        elif a == tmp_gp[-1] + 1:
            tmp_gp.append(a)
        # Otherwise close the run, e.g. age_group = [[2, 3, 4]], tmp_gp = [6]
        else:
            age_group.append(tmp_gp)
            tmp_gp = [a]
    age_group.append(tmp_gp)
    # e.g. age_group = [[2, 3, 4], [6, 7, 8, 9]]

    # Break longer runs into overlapping windows of 3,
    # e.g. [[2, 3, 4], [6, 7, 8, 9]] --> [[2, 3, 4], [6, 7, 8], [7, 8, 9]]
    age_group_new = []
    for lst in age_group:
        if len(lst) > 3:
            for i in range(1, len(lst) - 1):
                age_group_new.append([lst[i - 1], lst[i], lst[i + 1]])
        else:
            age_group_new.append(lst)
    return age_group_new
cluster_ages = udf(cluster_ages, ArrayType(ArrayType(IntegerType())))

# Choose the first PI_ID in the record's group of PI_IDs. They are essentially the same ID,
# because they come from overlapping windows of the same age range.
# Note: collect_set does not guarantee element order, so "first" is arbitrary.
def assign_PID(array):
    return array[0]
assign_PID = udf(assign_PID, LongType())  # PI_ID comes from monotonically_increasing_id (a long)

def do_ER(input_sdf):
    # ASSIGN UNIQUE PIDs
    # input_sdf = input_sdf.withColumn('PI_AGE', standardize_age('PI_AGE', 'PI_AGEUNIT'))
    # Group the ages of every (PI_FROM, PI_NAME, PI_SEX) combination into a list
    age_gp = input_sdf.groupBy('PI_FROM', 'PI_NAME', 'PI_SEX').agg(collect_set('PI_AGE').alias('AGE_GROUP'))
    # Cluster the ages
    age_clust = age_gp.withColumn('AGE_CLUST', cluster_ages('AGE_GROUP'))
    # Explode the clusters into one row per age window
    clust_expld = age_clust.select('PI_FROM', 'PI_NAME', 'PI_SEX', explode('AGE_CLUST').alias('AGE_GROUP'))
    # Assign a PID to each row
    pid_sdf = clust_expld.withColumn('PI_ID', monotonically_increasing_id())
    pid_sdf = pid_sdf.toDF('T_FROM', 'T_NAME', 'T_SEX', 'AGE_GROUP', 'PI_ID')
    # Join with the original SDF; ORIGREC and ORIGSTS are kept because the steps below use them
    cols = [input_sdf.PI_NAME, input_sdf.PI_FROM, input_sdf.PI_AGE, input_sdf.PI_SEX,
            input_sdf.ORIGREC, input_sdf.ORIGSTS, pid_sdf.PI_ID]
    join_sdf = input_sdf.join(pid_sdf).where(
        (input_sdf.PI_FROM == pid_sdf.T_FROM) &
        (input_sdf.PI_NAME == pid_sdf.T_NAME) &
        (input_sdf.PI_SEX == pid_sdf.T_SEX) &
        contains(pid_sdf.AGE_GROUP, input_sdf.PI_AGE)
    ).select(cols)

    # ELIMINATE EXTRA/DUPLICATE ROWS
    # Collect the records that received multiple PIDs (ages in overlapping windows)
    unique_sdf = join_sdf.groupBy('ORIGREC').agg(collect_set('PI_ID').alias('PI_GROUP'))
    # Pick one PID and assign it to the record
    unique_sdf = unique_sdf.withColumn('PI_ID', assign_PID('PI_GROUP'))
    unique_sdf = unique_sdf.select('ORIGREC', 'PI_ID').toDF('T_ORIGREC', 'T_ID')
    # Join it back with the SDF
    cols = [join_sdf.PI_NAME, join_sdf.PI_FROM, join_sdf.PI_AGE, join_sdf.PI_SEX,
            join_sdf.PI_ID, join_sdf.ORIGREC, join_sdf.ORIGSTS]
    join_sdf = unique_sdf.join(join_sdf).where(
        (unique_sdf.T_ID == join_sdf.PI_ID) & (unique_sdf.T_ORIGREC == join_sdf.ORIGREC)
    ).select(cols)
    return join_sdf

# ER on the SDF (AK_sdf is the input DataFrame; it is not defined in this excerpt)
ER_sdf = do_ER(AK_sdf)
```
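The windowing logic in do_ER is hard to follow through the Spark joins, so here is a Spark-free sketch of the same rule on a plain Python list. The helper names `age_windows` and `assign_ids` are hypothetical, and choosing the smallest window index is an assumption that stands in for assign_PID's "first element" (whose order collect_set does not guarantee).

```python
def age_windows(ages):
    """Same grouping rule as cluster_ages: runs of consecutive ages, with runs
    longer than 3 broken into overlapping windows of 3."""
    runs = []
    for a in sorted(set(ages)):
        if runs and a == runs[-1][-1] + 1:
            runs[-1].append(a)            # extend the current run
        else:
            runs.append([a])              # start a new run
    windows = []
    for run in runs:
        if len(run) > 3:
            windows.extend(run[i - 1:i + 2] for i in range(1, len(run) - 1))
        else:
            windows.append(run)
    return windows


def assign_ids(ages):
    """Map each age to the smallest index of a window that contains it."""
    windows = age_windows(ages)
    return {a: min(i for i, w in enumerate(windows) if a in w) for a in set(ages)}


print(age_windows([2, 3, 4, 6, 7, 8, 9]))  # [[2, 3, 4], [6, 7, 8], [7, 8, 9]]
print(assign_ids([2, 3, 4, 6, 7, 8, 9]))   # {2: 0, 3: 0, 4: 0, 6: 1, 7: 1, 8: 1, 9: 2}
```

The output exposes one property of the overlapping windows: ages 8 and 9 differ by only one year, yet receive different IDs because 9 appears only in the last window.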
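ER with BIRTH_YEAR instead of PI_AGE

Records of the same patient can be created years apart, so the recorded PI_AGE drifts over time. compute_ER_sdf2 in the class below therefore soft-clusters on BIRTH_YEAR, which get_birth_year derives as the year of REC_CREATEDDATE minus PI_AGE and which should stay roughly constant for one person.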







-----


# ENTITY RESOLUTION CLASS (optimized to the best of my knowledge)

```python

class entity_resolution():
    # Init the class values from the input parameter dictionary
    def __init__(self, params):
        self.soft_cols = params['soft_cols']
        self.hard_cols = params['hard_cols']
        self.soft_params = params['soft_params']
        self.prim_key = params['prim_key']
        #self.soft_mode = params['soft_mode']
        self.spark = ''
        self.orig_sdf = ''
        self.ER_sdf = ''
        self.ER_sdf2 = ''
        # Values that could be passed through the parameter dictionary
        self.os_environ_path = ''
        self.AWS_ACCESS_KEY = ''
        self.AWS_SECRET_KEY = ''
        self.AWS_S3_ENDPOINT = ''
        self.spark_jar_path = ''
        self.csv_path = ''

    # Set up the Spark session
    def setup_spark_session(self):
        os.environ["PYSPARK_PYTHON"] = "/home/hadoop/anaconda/envs/playground_py36/bin/python" #self.os_environ_path
        # Stop the session this object created earlier, if any
        if self.spark:
            self.spark.stop()
            print("Stopped a SparkSession")
        else:
            print("No existing SparkSession")

        SPARK_DRIVER_MEMORY = "10G"
        SPARK_DRIVER_CORE = "5"
        SPARK_EXECUTOR_MEMORY = "3G"
        SPARK_EXECUTOR_CORE = "1"
        # Read the credentials from the object or the standard AWS environment variables
        AWS_ACCESS_KEY = self.AWS_ACCESS_KEY or os.environ.get("AWS_ACCESS_KEY_ID", "")
        AWS_SECRET_KEY = self.AWS_SECRET_KEY or os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        AWS_S3_ENDPOINT = self.AWS_S3_ENDPOINT or "s3.cn-north-1.amazonaws.com.cn"

        # The 'yarn-client' master is deprecated since Spark 2.0: use 'yarn' in client deploy mode
        conf = SparkConf()\
                .setAppName("dian-ER")\
                .setMaster('yarn')\
                .set('spark.submit.deployMode', 'client')\
                .set('spark.jars', 's3://engineering.insightzen.com/bins/crunch-core-0.15.0.jar')\
                .set('spark.executor.cores', SPARK_EXECUTOR_CORE)\
                .set('spark.executor.memory', SPARK_EXECUTOR_MEMORY)\
                .set('spark.driver.cores', SPARK_DRIVER_CORE)\
                .set('spark.driver.memory', SPARK_DRIVER_MEMORY)\
                .set('spark.driver.maxResultSize', '0')  # 0 = unlimited

        spark = SparkSession.builder.\
            config(conf=conf).\
            getOrCreate()

        sc = spark.sparkContext
        hadoop_conf = sc._jsc.hadoopConfiguration()
        hadoop_conf.set("fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
        hadoop_conf.set("fs.s3a.access.key", AWS_ACCESS_KEY)
        hadoop_conf.set("fs.s3a.secret.key", AWS_SECRET_KEY)
        hadoop_conf.set("fs.s3a.endpoint", AWS_S3_ENDPOINT)
        hadoop_conf.set("mapreduce.fileoutputcommitter.algorithm.version", "2")

        # Init the ER module's Spark instance
        self.spark = spark
         
    #Load the CSV data into a spark dataframe and standardize age. Also select rows which don't have invalid PI_AGE	
    def load_data(self):
        self.orig_sdf = self.spark.read.option("header","true") \
                                      .option("multiLine", "true") \
                                      .csv("s3a://healthcrosscn.data.insightzen.com/dian/csv/检测结果*.csv")
                                      #.csv(self.csv_path + "/检测结果*.csv")
        self.orig_sdf = self.orig_sdf.withColumn('PI_AGE', standardize_age('PI_AGE', 'PI_AGEUNIT'))
        self.orig_sdf = self.orig_sdf.where(self.orig_sdf['PI_AGE'] != 0)
     
    # Do entity resolution, i.e. assign a PI_ID to every row.
    def compute_ER_sdf(self):
        """
        The Hard clustering performs a strict match among all the records on the specified
        Hard constraints and clusters the records accordingly.
        EG: Hard clustering based on the Hard constraints PI_NAME, PI_FROM, PI_SEX.
        """
        # Init the values passed in the parameter dictionary
        # Columns on which Hard Clustering is performed = [PI_FROM, PI_NAME, PI_SEX]
        hard_cols = self.hard_cols
        # Columns on which Soft Clustering is performed = [PI_AGE]
        soft_cols = self.soft_cols
        # Primary key used to ID the records = ORIGREC
        prim_key = self.prim_key
        # Parameters passed to the DBSCAN code, one (eps, min_points) pair per soft column = [(2.0, 1)]
        soft_params = self.soft_params

        # Select all the concerned columns from the original dataframe
        ip_df = self.orig_sdf.select(hard_cols + soft_cols + [prim_key]).persist()

        # Prepare a list of collect_list objects for the primary key and all columns in soft_cols
        soft_collect = [collect_list(prim_key).alias('REC_LIST')] + [collect_list(i).alias(i + '_LIST') for i in soft_cols]
        # GroupBy the hard cols and collect the soft cols
        hard_er_df = ip_df.groupBy(hard_cols).agg(*soft_collect)

        # List of cols to be grouped so that they can be sent to the soft clustering module
        group_cols = ['REC_LIST'] + [col + '_LIST' for col in soft_cols]
        # convert_list groups all the columns into a single list for the soft clustering module
        soft_df = hard_er_df.withColumn('S_CLUST_VEC', convert_list(*group_cols)).select('S_CLUST_VEC')
        # Perform soft clustering. Output is a clustered list of records based on the soft_cols values, EG: ip=[123, 233, 121, 121], op=[[123], [233, 121], [121]]
        soft_df = soft_df.withColumn('S_CLUST_VEC', ER_soft_cluster_udf(soft_params)('S_CLUST_VEC'))
        # Explode the list of soft clusters, so each row now holds one soft cluster
        soft_df = soft_df.select(explode('S_CLUST_VEC').alias('REC_LIST'))
        # Assign a PI_ID to each soft cluster. This will be our final Patient ID
        soft_df = soft_df.withColumn('PI_ID', monotonically_increasing_id())

        # Explode each cluster with its PI_ID into individual records
        self.ER_sdf = soft_df.select('PI_ID', explode('REC_LIST').alias('RECORD_ID'))

        # List of cols to get from the original DataFrame (ORIGREC is needed by the final select)
        join_cols = [self.orig_sdf.PI_NAME, self.orig_sdf.PI_FROM, self.orig_sdf.PI_AGE,
                     self.orig_sdf.PI_SEX, self.orig_sdf.ORIGREC, self.ER_sdf.PI_ID]

        # Join the DF with PI_ID to the original DF so that all records now contain a PI_ID.
        self.ER_sdf = self.ER_sdf.join(self.orig_sdf, self.orig_sdf.ORIGREC == self.ER_sdf.RECORD_ID, how='inner').select(join_cols)
        # Doing this to optimize the get_id function. Can probably be further optimized or modified based on the input parameters being supplied.
        # For now I am assuming that the matching is done with respect to 'PI_FROM', 'PI_NAME', 'PI_AGE', 'PI_SEX' OR with respect to a Record ID
        self.ER_sdf = self.ER_sdf.select('PI_FROM', 'PI_NAME', 'PI_AGE', 'PI_SEX', 'PI_ID', 'ORIGREC').persist()
  
    # ER variant that soft-clusters on BIRTH_YEAR instead of PI_AGE
    def compute_ER_sdf2(self):
        # Init the values passed in the parameter dictionary
        # Columns on which Hard Clustering is performed = [PI_FROM, PI_NAME, PI_SEX]
        hard_cols = self.hard_cols
        # Soft Clustering is performed on the derived birth year instead of PI_AGE
        soft_cols = ['BIRTH_YEAR']
        # Primary key used to ID the records = ORIGREC
        prim_key = self.prim_key
        # Parameters passed to the DBSCAN code, one (eps, min_points) pair per soft column = [(2.0, 1)]
        soft_params = self.soft_params

        # Derive BIRTH_YEAR and select all the concerned columns from the original dataframe
        ip_df = self.orig_sdf.withColumn('BIRTH_YEAR', get_birth_year('REC_CREATEDDATE', 'PI_AGE'))
        ip_df = ip_df.select(hard_cols + soft_cols + [prim_key]).persist()

        # Prepare a list of collect_list objects for the primary key and all columns in soft_cols
        soft_collect = [collect_list(prim_key).alias('REC_LIST')] + [collect_list(i).alias(i + '_LIST') for i in soft_cols]
        # GroupBy the hard cols and collect the soft cols
        hard_er_df = ip_df.groupBy(hard_cols).agg(*soft_collect)

        # List of cols to be grouped so that they can be sent to the soft clustering module
        group_cols = ['REC_LIST'] + [col + '_LIST' for col in soft_cols]
        # convert_list groups all the columns into a single list for the soft clustering module
        soft_df = hard_er_df.withColumn('S_CLUST_VEC', convert_list(*group_cols)).select('S_CLUST_VEC')
        # Perform soft clustering. Output is a clustered list of records based on the soft_cols values, EG: ip=[123, 233, 121, 121], op=[[123], [233, 121], [121]]
        soft_df = soft_df.withColumn('S_CLUST_VEC', ER_soft_cluster_udf(soft_params)('S_CLUST_VEC'))
        # Explode the list of soft clusters, so each row now holds one soft cluster
        soft_df = soft_df.select(explode('S_CLUST_VEC').alias('REC_LIST'))
        # Assign a PI_ID to each soft cluster. This will be our final Patient ID
        soft_df = soft_df.withColumn('PI_ID', monotonically_increasing_id())

        # Explode each cluster with its PI_ID into individual records
        self.ER_sdf2 = soft_df.select('PI_ID', explode('REC_LIST').alias('RECORD_ID'))

        # List of cols to get from the original DataFrame
        join_cols = [self.orig_sdf.PI_NAME, self.orig_sdf.PI_FROM, self.orig_sdf.PI_AGE, self.orig_sdf.PI_SEX,
                     self.ER_sdf2.PI_ID, self.orig_sdf.ORIGREC, self.orig_sdf.ORIGSTS, self.orig_sdf.BARCODE]

        # Join the DF with PI_ID to the original DF so that all records now contain a PI_ID.
        self.ER_sdf2 = self.ER_sdf2.join(self.orig_sdf, self.orig_sdf.ORIGREC == self.ER_sdf2.RECORD_ID, how='inner').select(join_cols)
        # Doing this to optimize the get_id function. Can probably be further optimized or modified based on the input parameters being supplied.
        # For now I am assuming that the matching is done with respect to 'PI_FROM', 'PI_NAME', 'PI_AGE', 'PI_SEX' OR with respect to a Record ID
        self.ER_sdf2 = self.ER_sdf2.select('PI_FROM', 'PI_NAME', 'PI_AGE', 'PI_SEX', 'PI_ID', 'ORIGREC').persist()
    # Get the PI_ID for the given input parameters.
    # This function is designed for DiAn's current specification of ER and needs to be modified if the ER features change.
    # The function could probably be optimized further to make it faster.
    def get_id(self, ip_params, match_params=True):
        # Supplied input parameters
        pi_from = ip_params['PI_FROM']
        pi_name = ip_params['PI_NAME']
        pi_age = ip_params['PI_AGE']
        pi_sex = ip_params['PI_SEX']
        # May or may not be provided. To match directly on the record ID, call with match_params=False
        pi_record_id = ip_params.get('PI_RECORD_ID')

        if match_params:
            # Select all records that match all the given parameters
            match_sdf = self.ER_sdf.where((self.ER_sdf['PI_FROM'] == pi_from) &
                                          (self.ER_sdf['PI_NAME'] == pi_name) &
                                          (self.ER_sdf['PI_SEX'] == pi_sex) &
                                          (self.ER_sdf['PI_AGE'] == pi_age)).select('PI_ID')
        else:
            # Match only on the record ID (ORIGREC)
            match_sdf = self.ER_sdf.where(self.ER_sdf['ORIGREC'] == pi_record_id).select('PI_ID')

        # first() returns None when no record matches, so a single Spark job
        # replaces the separate count() and first() calls
        row = match_sdf.first()
        return row.PI_ID if row is not None else None
```

Testing

```python
# Run in a Jupyter notebook: %time is an IPython magic.

# Specify the parameters for the ER module
params = {'soft_cols': ['PI_AGE'],
          'hard_cols': ['PI_FROM', 'PI_NAME', 'PI_SEX'],
          'soft_params': [(2.0, 1)],  # (eps, min_points) per soft column
          'prim_key': 'ORIGREC',
         }

# Initialize the class
er = entity_resolution(params)
# Start the Spark session
er.setup_spark_session()
# Load the CSV data
er.load_data()
# Perform ER
er.compute_ER_sdf()

# Now get the PI_ID for any input values
ip_param = {'PI_FROM': '19720000870825',
            'PI_NAME': '33800004',
            'PI_AGE': 87,
            'PI_SEX': '男',  # male
            'PI_RECORD_ID': ''
           }
# The first call always takes some time (about 16 minutes): persist() is lazy, so ER_sdf
# is only computed and cached by the first action. Calls are faster from the second
# call onwards because ER_sdf is then stored in memory.
%time PI_ID = er.get_id(ip_param)
# Prints None if no records were found
print(PI_ID)

# Matching with the record ID (ORIGREC)
ip_param = {'PI_FROM': '',
            'PI_NAME': '',
            'PI_AGE': None,
            'PI_SEX': '',
            'PI_RECORD_ID': '208668745'
           }
%time PI_ID = er.get_id(ip_param, False)
print(PI_ID)
```


Author: shiter