# beijing_taxi_2012 to pgdb: import the 2012 Beijing taxi trajectory files into PostgreSQL

import psycopg2
import pandas as pd
from io import StringIO
import csv
from itertools import islice
import os
import numpy as np
from datetime import datetime

# Characters whose presence marks a field as garbled. This is the same set the
# original literal "?>>~!@#$%^&*()_+-*/<>,.[]\/" produced, written without the
# invalid "\/" escape sequence.
_GARBLED_SYMBOLS = frozenset("?>~!@#$%^&*()_+-/<,.[]\\")


def if_contain_symbol(keyword):
    """Return True if ``keyword`` contains any character from ``_GARBLED_SYMBOLS``.

    Used to detect garbled records: a field containing punctuation such as
    ``_``, ``/`` or ``\\`` is treated as corrupt.

    :param keyword: the string to inspect.
    :returns: True if at least one symbol occurs in ``keyword``, else False
        (an empty string yields False).
    """
    return any(ch in _GARBLED_SYMBOLS for ch in keyword)

def table_exist(table_name=None, conn=None, cur=None):
    """Report whether ``table_name`` resolves to an existing relation.

    Runs PostgreSQL's ``to_regclass(name) IS NOT NULL`` through ``cur``.

    :param table_name: relation name, optionally schema-qualified.
    :param conn: connection that owns ``cur``; rolled back if the query fails.
    :param cur: DB-API cursor used to run the query.
    :returns: the boolean from the query (True when the relation exists,
        False otherwise), or None when the query failed or returned no rows.
    """
    try:
        # Bind the name as a parameter instead of splicing it into the SQL text.
        cur.execute("select to_regclass(%s) is not null", (table_name,))
        rows = cur.fetchall()
    except Exception:
        # A failed statement leaves the transaction aborted. Roll back so the
        # caller can keep using the connection; the previous code closed it,
        # which made every later statement in the script fail.
        if conn is not None:
            conn.rollback()
        return None
    if rows:
        return rows[0][0]
    return None

def _parse_trj_row(row, ts_m):
    """Convert one raw record into the 11 output fields, or return None to drop it.

    The text after the last ``$`` in field 1 becomes the company code. The
    last field is cut at the first ``#``. Rows whose unit id (field 2)
    contains a symbol recognised by ``if_contain_symbol`` are discarded as
    garbled.
    """
    # Blank or truncated lines lack the fields indexed below (up to row[11]);
    # previously they aborted the whole conversion with IndexError.
    if len(row) < 12:
        return None
    row[1] = row[1].split("$")[-1]
    row[-1] = row[-1].split("#")[0]
    if if_contain_symbol(row[2]):  # garbled-record check
        return None
    ts_s = datetime.strptime(row[3], '%Y%m%d%H%M%S').strftime('%Y-%m-%d %H:%M:%S')
    # serial_number, code_company, unit_id, ts_m, ts_s, lon, lat,
    # speed, direction, state, event
    return [row[0], row[1], row[2], ts_m, ts_s, row[4], row[5],
            row[8], row[9], row[10], row[11]]


def txt2csv(txt_dir, csv_save_dir):
    """Convert every raw trajectory file in ``txt_dir`` into a headerless CSV.

    For each file, the trailing ``_YYYYmmddHHMMSS`` stamp in its name becomes
    the ``ts_m`` value of every output row. The first line of each file is
    skipped. Output is written to ``csv_save_dir`` (created if missing) under
    the same base name, UTF-8 encoded, with columns serial_number,
    code_company, unit_id, ts_m, ts_s, lon, lat, speed, direction, state,
    event.

    :param txt_dir: directory containing the raw comma-separated files.
    :param csv_save_dir: directory the converted files are written to.
    """
    columns = ['serial_number', 'code_company', 'unit_id', 'ts_m',
               'ts_s', 'lon', 'lat', 'speed', 'direction', 'state', 'event']
    os.makedirs(csv_save_dir, exist_ok=True)
    for txt in os.listdir(txt_dir):
        txt_path = os.path.join(txt_dir, txt)
        stamp = txt.split("_")[-1].split(".")[0]
        ts_m = datetime.strptime(stamp, '%Y%m%d%H%M%S').strftime('%Y-%m-%d %H:%M:%S')
        # NOTE(review): the output keeps a ".txt" extension even though it is
        # CSV. filterByinteral only parses the timestamp in the name, so this
        # is harmless here; confirm nothing else depends on it before renaming.
        csv_file = os.path.join(csv_save_dir, txt.split(".")[0] + ".txt")
        trjs = []
        # newline='' is the mode the csv module expects for files it reads.
        with open(txt_path, 'r', newline='') as read_file:
            for row in islice(csv.reader(read_file), 1, None):  # skip the first line
                record = _parse_trj_row(row, ts_m)
                if record is not None:
                    trjs.append(record)
        df = pd.DataFrame(trjs, columns=columns)
        df.to_csv(csv_file, header=False, index=False, encoding="utf-8")
    print("txt2csv, down!!!")

def filterByinteral(time_interal, csv_save_dir):
    """Return the files in ``csv_save_dir`` whose name timestamp lies in an interval.

    :param time_interal: ``[start, end]`` as ``'YYYY-mm-dd HH:MM:SS'`` strings;
        both ends are inclusive. The list is left unmodified (the previous
        version overwrote its elements in place).
    :param csv_save_dir: directory whose entries are named
        ``..._YYYYmmddHHMMSS.<ext>``.
    :returns: sorted list of matching file names (not full paths). Entries
        whose name suffix is not an integer timestamp are skipped.
    """
    def to_stamp(text):
        # 'YYYY-mm-dd HH:MM:SS' -> YYYYmmddHHMMSS as an int, so the bounds
        # compare directly against the stamps embedded in the file names.
        return int(datetime.strptime(text, '%Y-%m-%d %H:%M:%S').strftime('%Y%m%d%H%M%S'))

    tm_min = to_stamp(time_interal[0])
    tm_max = to_stamp(time_interal[1])

    files_list = []
    for name in sorted(os.listdir(csv_save_dir)):
        stamp = name.split("_")[-1].split(".")[0]
        try:
            tm_m = int(stamp)
        except ValueError:
            continue  # not a timestamped data file
        if tm_min <= tm_m <= tm_max:
            files_list.append(name)
    return files_list

csv_save_dir = r'D:/DataWorkspace/data/20121024_csv'
txt_dir = r'D:/DataWorkspace/data/20121024'
# TXT --> CSV
txt2csv(txt_dir, csv_save_dir)

# Connect to the database.
conn = psycopg2.connect(database="beijing", user="jiangshan", password="jiangshan", host="localhost", port="5432")
cur = conn.cursor()

# Target table and SRID of its geometry column.
table_name = "taxi2012_bj"
the_geom_SRID = "4326"

# Spatio-temporal query point (lon, lat) and radius.
# NOTE(review): not used anywhere below; presumably kept for a later query.
point = (116.306251, 39.98070)
point_r = 1000.5

# Time-slice window: only files whose name timestamp falls in this range are imported.
time_interal = ['2012-10-24 10:00:00', '2012-10-24 13:00:00']

# Column layout of the CSV files written by txt2csv (and of the table, minus the_geom).
header_name_list = ['serial_number', 'code_company', 'unit_id', 'ts_m', 'ts_s', 'lon', 'lat', 'speed', 'direction', 'state', 'event']
dtype_dic = {'serial_number': object, 'code_company': object, 'unit_id': object, 'ts_m': object, 'ts_s': object, 'lon': object, 'lat': object, 'speed': float, 'direction': float, 'state': object, 'event': object}

try:
    # Create the table unless to_regclass confirmed it exists. table_exist
    # returns None when the check itself failed; CREATE ... IF NOT EXISTS is
    # safe in that case too (the old `is False` test skipped it).
    table_flg = table_exist(table_name, conn, cur)
    if not table_flg:
        sql = "CREATE TABLE IF NOT EXISTS {0} (serial_number BIGINT, code_company TEXT, unit_id BIGINT, ts_m TIMESTAMP, ts_s TIMESTAMP, lon DOUBLE PRECISION, lat DOUBLE PRECISION, speed FLOAT, direction FLOAT, state INT, event INT)".format(table_name)
        cur.execute(sql)
        conn.commit()
    files_list = filterByinteral(time_interal, csv_save_dir)

    # Bulk-load the rows. copy_from cannot load GEOMETRY values, so the
    # geometry column is filled in afterwards with an UPDATE.
    print('IMPORT FILES......')
    # Named csv_name, not csv: a module-level `csv` loop variable would shadow
    # the csv module that txt2csv relies on.
    for csv_name in files_list:
        print(csv_name)
        csv_path = os.path.join(csv_save_dir, csv_name)
        data = pd.read_csv(csv_path, header=None, names=header_name_list, dtype=dtype_dic)

        # Serialise the DataFrame to a tab-separated in-memory buffer. Missing
        # values are written as \N, the NULL marker copy_from expects by default
        # (an empty string is rejected by numeric/timestamp columns).
        output = StringIO()
        data.to_csv(output, sep='\t', index=False, header=False, na_rep='\\N')
        output.seek(0)
        # List the columns explicitly: once the_geom exists (e.g. on a re-run),
        # a column-less COPY would also expect a value for it.
        cur.copy_from(output, table_name, columns=header_name_list)
        conn.commit()
    print('IMPORT FILES OK!!')

    print('ADD A GEOMETRY COLUMN......')
    # IF NOT EXISTS (PostgreSQL 9.6+) keeps re-runs from failing on this step.
    cur.execute("alter table " + table_name + " add column if not exists the_geom GEOMETRY")
    conn.commit()

    print('UPDATE THE GEOMETRY.....')
    # Build the point directly from the numeric columns rather than via a WKT
    # string. Only rows not filled by an earlier run are touched.
    sql = "UPDATE public.{0} SET the_geom = ST_SetSRID(ST_MakePoint(lon, lat), {1}) WHERE the_geom IS NULL".format(table_name, the_geom_SRID)
    cur.execute(sql)
    conn.commit()
finally:
    # Release the cursor and connection even if a step above fails.
    cur.close()
    conn.close()
print('done')

 

# Posted 2021-05-21 20:04 by 土博姜山山 | views: 60 | comments: 0