beijing_taxi_2012 to pgdb
# Import the 2012 Beijing taxi GPS dumps into a PostgreSQL/PostGIS table.
#
# Pipeline: raw dumps --txt2csv--> cleaned CSV files
#           --filterByinteral--> files inside a time window
#           --main--> COPY into ``table_name`` + ``the_geom`` point column.
import csv
import os
from datetime import datetime
from io import StringIO
from itertools import islice

import numpy as np
import pandas as pd

try:
    import psycopg2
except ImportError:  # keep the pure helpers importable without the DB driver
    psycopg2 = None

# Timestamp layouts: compact form used in raw records and file names, and the
# form written to the CSV files / used by ``time_interal``.
_COMPACT_TS = '%Y%m%d%H%M%S'
_ISO_TS = '%Y-%m-%d %H:%M:%S'

# Column order of the cleaned CSV files and of the database table.
header_name_list = ['serial_number', 'code_company', 'unit_id', 'ts_m', 'ts_s',
                    'lon', 'lat', 'speed', 'direction', 'state', 'event']
# Columns typed ``object`` are kept as the raw strings read from the CSV.
dtype_dic = {'serial_number': object, 'code_company': object, 'unit_id': object,
             'ts_m': object, 'ts_s': object, 'lon': object, 'lat': object,
             'speed': float, 'direction': float, 'state': object, 'event': object}

# --- configuration ----------------------------------------------------------
csv_save_dir = r'D:/DataWorkspace/data/20121024_csv'
txt_dir = r'D:/DataWorkspace/data/20121024'

table_name = "taxi2012_bj"
the_geom_SRID = "4326"

# Spatio-temporal query point and radius in metres (not used by the import).
point = (116.306251, 39.98070)
point_r = 1000.5
# Only CSV files whose timestamp lies inside this window are imported.
time_interal = ['2012-10-24 10:00:00', '2012-10-24 13:00:00']


def if_contain_symbol(keyword):
    """Return True if *keyword* contains any symbol that marks a garbled record.

    Every character of the symbol set is checked; an empty *keyword* yields
    False.
    """
    symbols = "?>>~!@#$%^&*()_+-*/<>,.[]\\/"
    return any(symbol in keyword for symbol in symbols)


def table_exist(table_name=None, conn=None, cur=None):
    """Return True if *table_name* resolves to an existing relation, else False.

    The name is passed to PostgreSQL's ``to_regclass`` as a bound parameter
    rather than spliced into the SQL text.

    Raises:
        Exception: whatever the driver raised; the transaction on *conn* is
            rolled back first so the connection stays usable.
    """
    try:
        cur.execute("SELECT to_regclass(%s) IS NOT NULL", (table_name,))
        row = cur.fetchone()
    except Exception:
        conn.rollback()
        raise
    return bool(row[0]) if row else False


def txt2csv(txt_dir, csv_save_dir):
    """Convert every raw GPS dump in *txt_dir* into a cleaned CSV file.

    Each file name must end in ``_<YYYYmmddHHMMSS>.<ext>``; that timestamp is
    written as ``ts_m`` on every row of the file. Per file, the first line is
    skipped; rows with fewer than 12 fields, or whose ``unit_id`` contains a
    symbol (see :func:`if_contain_symbol`), are dropped. The kept fields are
    written without a header, in ``header_name_list`` order, to *csv_save_dir*.

    Raises:
        ValueError: if a file-name or record timestamp is not ``YYYYmmddHHMMSS``.
    """
    for txt in os.listdir(txt_dir):
        txt_path = os.path.join(txt_dir, txt)
        ts_m = datetime.strptime(txt.split("_")[-1].split(".")[0],
                                 _COMPACT_TS).strftime(_ISO_TS)
        # NOTE(review): the output keeps the ".txt" extension although it is
        # CSV; filterByinteral only relies on the "_<timestamp>" suffix.
        csv_file = os.path.join(csv_save_dir, txt.split(".")[0] + ".txt")

        trjs = []
        with open(txt_path, 'r', newline='') as read_file:
            for row in islice(csv.reader(read_file), 1, None):  # skip first line
                if len(row) < 12:
                    continue  # truncated record: fields up to index 11 are read
                row[1] = row[1].split("$")[-1]   # keep text after the last '$'
                row[-1] = row[-1].split("#")[0]  # drop the '#...' trailer
                if if_contain_symbol(row[2]):    # garbled unit id
                    continue
                ts_s = datetime.strptime(row[3], _COMPACT_TS).strftime(_ISO_TS)
                # Raw fields 6 and 7 are not exported.
                trjs.append([row[0], row[1], row[2], ts_m, ts_s, row[4], row[5],
                             row[8], row[9], row[10], row[11]])

        df = pd.DataFrame(trjs, columns=header_name_list)
        df.to_csv(csv_file, header=False, index=False, encoding="utf-8")
    print("txt2csv, down!!!")


def filterByinteral(time_interal, csv_save_dir):
    """Return names of files in *csv_save_dir* whose timestamp is in *time_interal*.

    Args:
        time_interal: ``[start, end]`` as ``'YYYY-mm-dd HH:MM:SS'`` strings;
            both bounds are inclusive. The sequence is not modified.
        csv_save_dir: directory produced by :func:`txt2csv`.

    Returns:
        Sorted (hence chronological for a common prefix) list of file names,
        not full paths. Files whose name does not end in
        ``_<YYYYmmddHHMMSS>.<ext>`` are skipped.
    """
    tm_min, tm_max = (
        int(datetime.strptime(bound, _ISO_TS).strftime(_COMPACT_TS))
        for bound in time_interal[:2]
    )
    files_list = []
    for name in os.listdir(csv_save_dir):
        try:
            tm_m = int(name.split("_")[-1].split(".")[0])
        except ValueError:
            continue  # not a "<prefix>_<timestamp>.<ext>" file
        if tm_min <= tm_m <= tm_max:
            files_list.append(name)
    return sorted(files_list)


def _load(conn, cur):
    """Create ``table_name`` if missing, COPY the CSV window, then fill ``the_geom``."""
    if not table_exist(table_name, conn, cur):
        cur.execute(
            "CREATE TABLE IF NOT EXISTS {0} (serial_number BIGINT, code_company TEXT, "
            "unit_id BIGINT, ts_m TIMESTAMP, ts_s TIMESTAMP, lon DOUBLE PRECISION, "
            "lat DOUBLE PRECISION, speed FLOAT, direction FLOAT, state INT, "
            "event INT)".format(table_name)
        )
        conn.commit()

    files_list = filterByinteral(time_interal, csv_save_dir)

    # copy_from cannot bulk-load GEOMETRY values, so the plain columns are
    # loaded first and the_geom is derived from lon/lat afterwards.
    print('IMPORT FILES......')
    for csv_name in files_list:  # not ``csv``: that would shadow the csv module
        print(csv_name)
        data = pd.read_csv(os.path.join(csv_save_dir, csv_name), header=None,
                           names=header_name_list, dtype=dtype_dic)
        buffer = StringIO()
        data.to_csv(buffer, sep='\t', index=False, header=False)
        buffer.seek(0)
        # Explicit columns keep COPY working once the_geom exists (re-runs).
        cur.copy_from(buffer, table_name, columns=header_name_list)
        conn.commit()
    print('IMPORT FILES OK!!')

    print('ADD A GEOMETRY COLUMN......')
    cur.execute("ALTER TABLE {0} ADD COLUMN IF NOT EXISTS the_geom GEOMETRY"
                .format(table_name))
    conn.commit()

    print('UPDATE THE GEOMETRY.....')
    # Only rows without a geometry yet (e.g. those imported by this run).
    cur.execute(
        "UPDATE public.{0} SET the_geom = ST_SetSRID(ST_MakePoint(lon, lat), {1}) "
        "WHERE the_geom IS NULL".format(table_name, the_geom_SRID)
    )
    conn.commit()


def main():
    """Connect to the database, run the import and always close the connection."""
    if psycopg2 is None:
        raise SystemExit("psycopg2 is required to load data into PostgreSQL")

    # Run once first if csv_save_dir is not populated yet:
    # txt2csv(txt_dir, csv_save_dir)

    conn = psycopg2.connect(database="beijing", user="jiangshan", password="jiangshan",
                            host="localhost", port="5432")
    try:
        with conn.cursor() as cur:
            _load(conn, cur)
    finally:
        conn.close()
    print('done')


if __name__ == "__main__":
    main()
个人学习记录