alex的ATM学习笔记
这是理解和学习ALex的简单ATM https://github.com/triaquae/py3_training/tree/master/atm
day5-atm/ ├── README ├── atm #ATM主程目录 │ ├── __init__.py │ ├── bin #ATM 执行文件 目录 │ │ ├── __init__.py │ │ ├── atm.py #ATM 执行程序 │ │ └── manage.py #ATM 管理端,未实现 │ ├── conf #配置文件 │ │ ├── __init__.py │ │ └── settings.py │ ├── core #主要程序逻辑都 在这个目录 里 │ │ ├── __init__.py │ │ ├── accounts.py #用于从文件里加载和存储账户数据 │ │ ├── auth.py #用户认证模块 │ │ ├── db_handler.py #数据库连接引擎 │ │ ├── logger.py #日志记录模块 │ │ ├── main.py #主逻辑交互程序 │ │ └── transaction.py #记账\还钱\取钱等所有的与账户金额相关的操作都 在这 │ ├── db #用户数据存储的地方 │ │ ├── __init__.py │ │ ├── account_sample.py #生成一个初始的账户数据 ,把这个数据 存成一个 以这个账户id为文件名的文件,放在accounts目录 就行了,程序自己去会这里找 │ │ └── accounts #存各个用户的账户数据 ,一个用户一个文件 │ │ └── 1234.json #一个用户账户示例文件 │ └── log #日志目录 │ ├── __init__.py │ ├── access.log #用户访问和操作的相关日志 │ └── transactions.log #所有的交易日志 └── shopping_mall #电子商城程序,需单独实现 └── __init__.py
#这是atm.py文件
import os import sys # base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 一个上上级的绝对路径并赋给变量 # # sys.path.append(base_dir) 把路径添加到调用模块时的查找路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) #这是上面两句和在一起 from core import main #通过添加路径,我们可以调用ATM目录下的core包的main模块 if __name__ == '__main__': # __name__ 这个函数保证只有在当前文件运行下面的内容 main.run() # 调用main模块下的run()函数
#这是main.py文件
from core import auth
from core import accounts
from core import logger
from core import accounts
from core import transaction
from core import db_handler
from core.auth import login_required
import time
import json
#transaction logger
# trans_logger = logger.logger('transaction') #一个打印transaction的日志生成器
# #access logger
# access_logger = logger.logger('access') #一个打印access的日志生成器
trc_logger = logger.trc_logger
acc_logger = logger.acc_logger
#temp account data ,only saves the data in memory
user_data = {
'account_id':None,
'is_authenticated':False,
'account_data':None
}
def account_info(acc_data):
print(user_data)
@login_required
def repay(acc_data): #acc_data是一个字典
'''
print current balance and let user repay the bill
:return:
'''
account_data = accounts.load_current_balance(acc_data['account_id']) #获得存储的用户信息文档
account_data = db_handler.write(account_data) #转换为字典
#for k,v in account_data.items():
# print(k,v )
current_balance= ''' --------- BALANCE INFO --------
Credit : %s
Balance: %s''' %(account_data['credit'],account_data['balance']) #打印出信用卡额度和现金
print(current_balance)
back_flag = False
while not back_flag:
repay_amount = input("\033[33;1mInput repay amount:\033[0m").strip() #输入还款金额
if len(repay_amount) >0 and repay_amount.isdigit():
print('还款')
new_balance = transaction.make_transaction(trc_logger,account_data,'repay', repay_amount) #
if new_balance:
print('''\033[42;1mNew Balance:%s\033[0m''' %(new_balance['balance']))
else:
print('\033[31;1m[%s] is not a valid amount, only accept integer!\033[0m' % repay_amount)
if repay_amount == 'b':
back_flag = True
def withdraw(acc_data):
'''
print current balance and let user do the withdraw action
:param acc_data:
:return:
'''
account_data = accounts.load_current_balance(acc_data['account_id'])
account_data = db_handler.write(account_data)
current_balance= ''' --------- BALANCE INFO --------
Credit : %s
Balance: %s''' %(account_data['credit'],account_data['balance'])
print(current_balance)
back_flag = False
while not back_flag:
withdraw_amount = input("\033[33;1mInput withdraw amount:\033[0m").strip()
if len(withdraw_amount) >0 and withdraw_amount.isdigit():
new_balance = transaction.make_transaction(trc_logger,account_data,'withdraw', withdraw_amount)
if new_balance:
print('''\033[42;1mNew Balance:%s\033[0m''' %(new_balance['balance']))
else:
print('\033[31;1m[%s] is not a valid amount, only accept integer!\033[0m' % withdraw_amount)
if withdraw_amount == 'b':
back_flag = True
def transfer(acc_data):
pass
def pay_check(acc_data):
pass
def logout(acc_data):
pass
def interactive(acc_data):
'''
interact with user
:return:
'''
menu = u'''
------- Oldboy Bank ---------
\033[32;1m1. 账户信息
2. 还款(功能已实现)
3. 取款(功能已实现)
4. 转账
5. 账单
6. 退出
\033[0m'''
menu_dic = {
'1': account_info,
'2': repay,
'3': withdraw,
'4': transfer,
'5': pay_check,
'6': logout,
}
exit_flag = False
while not exit_flag:
print(menu)
user_option = input(">>:").strip()
if user_option in menu_dic:
print('accdata',acc_data)
#acc_data['is_authenticated'] = False
menu_dic[user_option](acc_data)
else:
print("\033[31;1mOption does not exist!\033[0m")
def run():
'''
this function will be called right a way when the program started, here handles the user interaction stuff
:return:
'''
acc_data = auth.acc_login(user_data,acc_logger) #赋值并执行auth模块下的认证函数,把用户数据字典(空)和日志器传到这个函数,执行成功会返回一个储存的用户字典
if user_data['is_authenticated']: #如果认证成功(开始为FALSE,认证成功后为TRUE)
user_data['account_data'] = acc_data #ud字典的ad键内容就嵌套一个用户字典
interactive(user_data) #操作信用卡的函数
#这是auth.py文件
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core import db_handler
from conf import settings
from core import logger
import json
import time
def login_required(func):
"验证用户是否登录"
def wrapper(*args,**kwargs):
#print('--wrapper--->',args,kwargs)
if args[0].get('is_authenticated'):
return func(*args,**kwargs)
else:
exit("User is not authenticated.")
return wrapper
def acc_auth(account,password):
'''
account auth func
:param account: credit account number
:param password: credit card password
:return: if passed the authentication , retun the account object, otherwise ,return None
'''
db_path = db_handler.db_handler(settings.DATABASE) #拿到一个存储账户信息文件的路径,使用函数是为了扩展可能的方法(从数据库)
account_file = "%s/%s.json" %(db_path,account) #从上面的路径获取存储某用户信息的文件
# print(account_file)
if os.path.isfile(account_file): #os.path.isfile方法判断是否存括号内的文件
with open(account_file,'r') as f: #确定存在后,以读模式打开
account_data = json.load(f) #用json模块的load方法读取内容
if account_data['password'] == password: #如果密码正确
exp_time_stamp = time.mktime(time.strptime(account_data['expire_date'], "%Y-%m-%d")) #从用户信息字典找到过期时间转换成秒
if time.time() >exp_time_stamp: #当前的秒如果大于过期秒
print("\033[31;1mAccount [%s] has expired,please contact the back to get a new card!\033[0m" % account) #提示过期
else: #passed the authentication
return account_data #没过期就返回用户信息字典
else:
return False
print("\033[31;1mAccount ID or password is incorrect!\033[0m") #如果密码不正确的提示
else:
return False
print("\033[31;1mAccount [%s] does not exist! FLAG\033[0m" % account) #如果不存在文件的提示
# logger = logger.acc_logger
if __name__ == '__main__': logger.debug('welcome') def acc_auth2(account,password): ''' 优化版认证接口 :param account: credit account number :param password: credit card password :return: if passed the authentication , retun the account object, otherwise ,return None ''' db_api = db_handler.db_handler() data = db_api("select * from accounts where account=%s" % account) if data['password'] == password: exp_time_stamp = time.mktime(time.strptime(data['expire_date'], "%Y-%m-%d")) if time.time() > exp_time_stamp: print("\033[31;1mAccount [%s] has expired,please contact the back to get a new card!\033[0m" % account) else: # passed the authentication return data else: print("\033[31;1mAccount ID or password is incorrect!\033[0m") def acc_login(user_data,log_obj): ''' account login func :user_data: user info data , only saves in memory :return: ''' retry_count = 0 #初始次数 while user_data['is_authenticated'] is not True and retry_count < 3 : #如果用户认证属性为Fales且初始次数小于3 account = input("\033[32;1maccount:\033[0m").strip() #输入账号 password = input("\033[32;1mpassword:\033[0m").strip() #输入密码 auth = acc_auth(account, password) #把输入的账号密码传给认证函数去确认是否正确 正确就返回了一个用户字典 if auth: #not None means passed the authentication #认证成功 user_data['is_authenticated'] = True #修改认证属性为True user_data['account_id'] = account #修改用户账号为认证成功的账号 logger.acc_logger.debug('welcome') return auth #返回这个字典 retry_count +=1 else: log_obj.error("account [%s] too many login attempts" % account) exit()
#这是db_handler.py文件
''' handle all the database interactions ''' import json,time ,os from conf import settings def file_db_handle(conn_params): ''' parse the db file path :param conn_params: the db connection params set in settings :return: ''' # print('file db:',conn_params) #打印DATABASE字典,无用 db_path ='%s/%s' %(conn_params['path'],conn_params['name']) #一个ATM/db/accounts/的路径,即储存账户信息的路径 return db_path #返回这个路径 def db_handler(conn_params): ''' connect to db :param conn_parms: the db connection params set in settings :return:a ''' # conn_params = settings.DATABASE if conn_params['engine'] == 'file_storage': #意思是如果储存方式为文件储存, return file_db_handle(conn_params) #那么返回一个从文件读取数据的函数 # elif conn_params['engine'] == 'mysql': # 这是扩展为数据库的方法....提供思路 无用 # # pass #todo def write(file): #这是我添加的函数,因为会报错 f = open(file,'r') account_data = json.load(f) return account_data def file_execute(sql,**kwargs): conn_params = settings.DATABASE db_path = '%s/%s' % (conn_params['path'], conn_params['name']) print(sql,db_path) sql_list = sql.split("where") print(sql_list) if sql_list[0].startswith("select") and len(sql_list)> 1:#has where clause column,val = sql_list[1].strip().split("=") if column == 'account': account_file = "%s/%s.json" % (db_path, val) print(account_file) if os.path.isfile(account_file): with open(account_file, 'r') as f: account_data = json.load(f) return account_data else: exit("\033[31;1mAccount [%s] does not exist!\033[0m" % val ) elif sql_list[0].startswith("update") and len(sql_list)> 1:#has where clause column, val = sql_list[1].strip().split("=") if column == 'account': account_file = "%s/%s.json" % (db_path, val) #print(account_file) if os.path.isfile(account_file): account_data = kwargs.get("account_data") with open(account_file, 'w') as f: acc_data = json.dump(account_data, f) return True
#这是accounts.py文件
import json import time from core import db_handler from conf import settings def load_current_balance(account_id): ''' return account balance and other basic info :param account_id: :return: ''' db_path = db_handler.db_handler(settings.DATABASE) #获得存储用户信息文件夹的路径 account_file = "%s/%s.json" %(db_path,account_id) #获取存储用户信息的文件 # # db_api = db_handler.db_handler() # data = db_api("select * from accounts where account=%s" % account_id) return account_file # with open(account_file) as f: # acc_data = json.load(f) # return acc_data def dump_account(account_data): ''' after updated transaction or account data , dump it back to file db :param account_data: :return: ''' # db_api = db_handler.db_handler() # data = db_api("update accounts where account=%s" % account_data['id'],account_data=account_data) db_path = db_handler.db_handler(settings.DATABASE) account_file = "%s/%s.json" %(db_path,account_data['id']) with open(account_file, 'w') as f: acc_data = json.dump(account_data,f) return True
#这是transaction.py文件
def make_transaction(log_obj,account_data,tran_type,amount,**others): #日志,账户字典,操作类型,金额,其他 ''' deal all the user transactions :param account_data: user account data :param tran_type: transaction type :param amount: transaction amount :param others: mainly for logging usage :return: ''' amount = float(amount) if tran_type in settings.TRANSACTION_TYPE: interest = amount * settings.TRANSACTION_TYPE[tran_type]['interest'] #计算利息 old_balance = account_data['balance'] #之前的现金 if settings.TRANSACTION_TYPE[tran_type]['action'] == 'plus': # new_balance = old_balance + amount + interest # elif settings.TRANSACTION_TYPE[tran_type]['action'] == 'minus': new_balance = old_balance - amount - interest #check credit if new_balance <0: print('''\033[31;1mYour credit [%s] is not enough for this transaction [-%s], your current balance is [%s]''' %(account_data['credit'],(amount + interest), old_balance )) return account_data['balance'] = new_balance accounts.dump_account(account_data) #save the new balance back to file log_obj.info("account:%s action:%s amount:%s interest:%s" % (account_data['id'], tran_type, amount,interest) ) return account_data else: print("\033[31;1mTransaction type [%s] is not exist!\033[0m" % tran_type)
#这是logger.py文件
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import logging
from conf import settings
def logger(log_type):
#create logger
logger = logging.getLogger(log_type)
logger.setLevel(settings.LOG_LEVEL)
# create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(settings.LOG_LEVEL)
# create file handler and set level to warning
log_file = "%s/log/%s" %(settings.BASE_DIR, settings.LOG_TYPES[log_type])
fh = logging.FileHandler(log_file)
fh.setLevel(settings.LOG_LEVEL)
# create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# add formatter to ch and fh
ch.setFormatter(formatter)
fh.setFormatter(formatter)
# add ch and fh to logger
logger.addHandler(ch)
logger.addHandler(fh)
return logger
# 'application' code
'''logger.debug('debug message')
logger.info('info message')
logger.warn('warn message')
logger.error('error message')
logger.critical('critical message')'''
acc_logger = logger('access')
trc_logger = logger('transaction')
# acc_logger.debug('123')