Common Python Code Snippets (General)

A unified logging interface with the logging library

import os
import logging

def init_log(output_dir):
    # Log everything to a file; echo INFO and above to the console
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y%m%d-%H:%M:%S',
                        filename=os.path.join(output_dir, 'log.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    return logging
	
# Usage
logging = init_log(SAVE_DIR)
logging.info('| epoch:{} | train | loss: {:.3f} | correct: {} | total: {} | acc: {:.3f} |'.format(
    epoch, loss, train_correct, train_total, train_acc))

Counting the number of samples in each class for a multi-class classification task

import numpy as np

def statics(label_arr):
    cls = np.unique(label_arr)
    print('*******Dataset*******')
    for c in list(cls):
        count = np.sum(label_arr == c)
        print('| class', c, '| count', count)
    print('| total', label_arr.shape[0])
    print('*' * 20)

Combining two columns into one: map casts each column to str so they can be concatenated

df['data'] = df['year'].map(str)+'-'+df['month'].map(str)
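For example, a minimal sketch with made-up year/month values:

```python
import pandas as pd

# hypothetical data; any int columns work the same way
df = pd.DataFrame({'year': [2020, 2021], 'month': [1, 12]})
df['data'] = df['year'].map(str) + '-' + df['month'].map(str)
print(df['data'].tolist())  # → ['2020-1', '2021-12']
```

Note that int-to-str conversion does not zero-pad, so month 1 becomes '1', not '01'.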

Saving a Python dict as a CSV file:

import pandas as pd

df = pd.DataFrame.from_dict(csv_dict, orient='columns')
df.to_csv('/home/lxg/data/animal_data/Result.csv')
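A self-contained sketch (the dict contents here are made up; with orient='columns' each key becomes a column):

```python
import pandas as pd

csv_dict = {'animal': ['cat', 'dog'], 'count': [3, 5]}  # hypothetical data
df = pd.DataFrame.from_dict(csv_dict, orient='columns')
csv_text = df.to_csv(index=False)  # no path given: returns the CSV as a string
print(csv_text)
```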

Merging a list of DataFrames with functools.reduce

import functools

df = functools.reduce(
    lambda left, right: pd.merge(left, right, how='left', on=['id', 'year']),
    [maps, pp, pp_doy_rainDayCounts, pp_moy_rainZscore, modis_temp, pop]
)
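A toy illustration with made-up tables (the real maps, pp, etc. are the project's own DataFrames; the stand-ins below only exist for this sketch):

```python
import functools
import pandas as pd

# hypothetical stand-ins for the real tables
maps = pd.DataFrame({'id': [1, 2], 'year': [2020, 2020], 'a': [10, 20]})
pp = pd.DataFrame({'id': [1, 2], 'year': [2020, 2020], 'b': [0.1, 0.2]})
pop = pd.DataFrame({'id': [1], 'year': [2020], 'c': [100]})

df = functools.reduce(
    lambda left, right: pd.merge(left, right, how='left', on=['id', 'year']),
    [maps, pp, pop],
)
print(df.columns.tolist())  # → ['id', 'year', 'a', 'b', 'c']
```

Because how='left' is used, rows missing from a right-hand table (id=2 in pop here) keep their row and get NaN in that table's columns.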

Deep copy

import copy
copy.deepcopy(init_maps)
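Why deepcopy rather than copy.copy: a shallow copy still shares nested objects with the original. A minimal demonstration (init_maps here is just a made-up nested dict):

```python
import copy

init_maps = {'grid': [[0, 0], [0, 0]]}  # hypothetical nested structure
shallow = copy.copy(init_maps)
deep = copy.deepcopy(init_maps)

init_maps['grid'][0][0] = 9
print(shallow['grid'][0][0])  # → 9 (shares the inner lists)
print(deep['grid'][0][0])     # → 0 (fully independent)
```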

Multiprocessing

from concurrent.futures import ProcessPoolExecutor

length = len(filelist)
files_iter = iter(filelist)
processes = 1     # number of worker processes
parallelism = 1   # chunks per process
with ProcessPoolExecutor() as executor:
    chunksize = int(max(length / (processes * parallelism), 1))
    executor.map(tf2pth_oneitem, files_iter, chunksize=chunksize)

Reading the contents and keys of a TFRecord file

TF2 (recommended)

import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset(input_file)
for i, raw_record in enumerate(raw_dataset):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    keys = list(example.features.feature)

tf.compat.v1

for record in tf.compat.v1.python_io.tf_record_iterator("data/foobar.tfrecord"):
    example = tf.train.Example.FromString(record)  # the iterator yields raw bytes
    keys = list(example.features.feature.keys())
    print(keys)

Implementing multiprocessing in Python

Batch-processing data in a for loop is a very common need, but when the processing pipeline is complex, handling items one by one is painfully slow. Python's multithreading has long been hampered by the GIL, so multiprocessing is the approach most people recommend. There are countless multiprocessing recipes online; following Occam's razor, here is one that takes little code and works well:

Normally, the sequential version of the logic looks like this:

for i in filelist:
    worker(i)

To parallelize it, we only need to import ProcessPoolExecutor and use its map function:

from concurrent.futures import ProcessPoolExecutor
with ProcessPoolExecutor() as executor:
    executor.map(worker, iter(filelist), chunksize=4)

ProcessPoolExecutor splits the iterable into a number of chunks and submits each chunk to the pool as a separate task. The (approximate) size of these chunks can be specified by setting chunksize to a positive integer. The default is 1; a larger chunksize can significantly improve performance, at the cost of more resources.
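The chunk-size heuristic used in the snippet above (length divided by processes × parallelism, floored at 1) can be written as a small helper; the function name here is my own:

```python
def pick_chunksize(n_items, processes, parallelism=1):
    # aim for roughly processes * parallelism chunks, with at least 1 item each
    return max(n_items // (processes * parallelism), 1)

print(pick_chunksize(100, 4))     # → 25
print(pick_chunksize(3, 8))       # → 1
print(pick_chunksize(100, 4, 2))  # → 12
```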

One caveat: with this multiprocessing approach, if the worker function has a bug, the program can terminate silently without reporting any error. My usual workarounds are: 1. before going parallel, run the worker() function in a plain for loop first — it is only a one-line change and easy to fix; 2. use print or logging to track progress and check whether the output is correct.
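The first workaround can be sketched like this: run the worker sequentially inside try/except so exceptions surface with a traceback before switching to executor.map (the worker and items below are made up for illustration):

```python
def worker(x):
    # hypothetical worker that fails on one input
    if x == 0:
        raise ValueError('bad item')
    return x * x

items = [1, 2, 0, 3]
results, failed = [], []
for i in items:
    try:
        results.append(worker(i))
    except Exception as exc:
        failed.append((i, str(exc)))
print(results)  # → [1, 4, 9]
print(failed)   # → [(0, 'bad item')]
```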

density_scatter

The motivation is that plots produced by code are more standardized. Metrics such as a scatter plot's fitted equation and \(R^2\) used to require lots of clicking in Excel to display, which is tedious when inspecting and analyzing many log files and makes patterns hard to spot.
Once the plotting code is written, logs can be visualized at a glance and compared. Because the plots follow a consistent style, they are easy to compare across experiments, look reasonably good, and with minor tweaks can serve as paper figures.
The following code implements a density scatter plot with matplotlib.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from scipy.interpolate import interpn

def density_scatter(x, y, ax=None, sort=True, bins=20, is_cbar=True, **kwargs):
    """
    Scatter plot colored by 2d histogram
    """
    if ax is None:
        fig, ax = plt.subplots()
    data, x_e, y_e = np.histogram2d(x, y, bins=bins, density=True)
    z = interpn((0.5 * (x_e[1:] + x_e[:-1]), 0.5 * (y_e[1:] + y_e[:-1])),
                data, np.vstack([x, y]).T,
                method="splinef2d", bounds_error=False)
    # To be sure to plot all data
    z[np.where(np.isnan(z))] = 0.0
    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]
    ax.scatter(x, y, c=z, **kwargs)
    ax.grid(linestyle='--', linewidth=0.5)
    norm = Normalize(vmin=np.min(z), vmax=np.max(z))
    if is_cbar:
        cbar = plt.colorbar(cm.ScalarMappable(norm=norm), ax=ax)
        cbar.ax.set_ylabel('Density')
    return ax

Drawing a progress bar by hand

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import sys, time

for i in range(30):  # number of ticks in the bar
    sys.stdout.write("*")
    sys.stdout.flush()
    time.sleep(0.2)
posted @ 2023-11-15 19:11 GeoAi