Files
DeepStock/analyzer/StockPriceDirectionAnalyzer.py

243 lines
8.4 KiB
Python

import os
import sqlite3
import pandas as pd
class StockPriceDirectionAnalyzer:
    """Score how closely other series track a master stock's later closing price.

    The analyzer reads a SQLite database containing the tables ``stock``,
    ``meta_1``, ``meta_2``, ``meta_3`` and ``meta_5``.  For every series that
    has enough rows it computes the Pearson correlation between that series
    and the master stock's close a fixed number of observations later.

    Attributes:
        stock_info: dict mapping a series code to its display label.  It is
            filled by :meth:`get_all_stock_code` and :meth:`getClosePrice`.
        stockFileName: path of the SQLite database file.
    """

    # Class-level placeholder; every instance gets its own dict in __init__.
    stock_info = None

    # Value columns of meta_2 in SELECT order, paired with their display labels.
    _INVESTOR_LABELS = (
        ('pri', '개인'),            # individuals
        ('fori', '외국인'),         # foreigners
        ('ins', '기관합'),          # institutions, total
        ('ins0', '금융투자'),       # financial investment
        ('ins1', '보험'),           # insurance
        ('ins2', '투신 (사모)'),    # investment trust (private)
        ('ins3', '은행'),           # banks
        ('ins4', '기타금융기관'),   # other financial institutions
        ('ins5', '연기금 등'),      # pension funds etc.
        ('cor', '기타법인'),        # other corporations
    )

    def __init__(self, stockFileName):
        """Remember the database path; no connection is opened yet."""
        self.stock_info = {}
        self.stockFileName = stockFileName

    def open(self):
        """Open a new connection to the database and create a cursor on it."""
        self.conn = sqlite3.connect(self.stockFileName)
        self.cursor = self.conn.cursor()

    def close(self):
        """Close the cursor and the connection created by :meth:`open`."""
        try:
            self.cursor.close()
        finally:
            self.conn.close()

    def get_all_stock_code(self):
        """Return every code in ``stock`` and record its name in ``stock_info``.

        Returns:
            list with one code per distinct ``(code, name)`` pair.
        """
        self.open()
        try:
            self.cursor.execute("SELECT distinct code, name FROM stock")
            rows = self.cursor.fetchall()
        finally:
            # Release the connection even when the query fails.
            self.close()
        for row in rows:
            self.stock_info[row[0]] = row[1]
        return [row[0] for row in rows]

    def load(self, master_code, limit_count=2000):
        """Return the most recent ``(ymd, close)`` rows of one stock.

        Args:
            master_code: code of the stock to read.
            limit_count: maximum number of rows to keep, counted from the
                newest.  A non-positive value yields an empty list.

        Returns:
            list of ``(ymd, close)`` tuples ordered by ``ymd`` ascending.
        """
        self.open()
        try:
            self.cursor.execute(
                "SELECT ymd, close FROM stock where CODE=? order by ymd",
                (master_code,))
            master_data = self.cursor.fetchall()
        finally:
            self.close()
        # ``seq[-0:]`` is the whole sequence, so zero must be handled apart.
        if limit_count <= 0:
            return []
        return master_data[-limit_count:]

    def _rows_if_enough(self, sql, params, limit_count):
        """Run *sql* on the open cursor; return its rows, or None if too few."""
        self.cursor.execute(sql, params)
        rows = self.cursor.fetchall()
        return rows if len(rows) >= limit_count else None

    def _distinct_codes(self, sql):
        """Run *sql* on the open cursor and return the first column as a list."""
        self.cursor.execute(sql)
        return [row[0] for row in self.cursor.fetchall()]

    def getClosePrice(self, stock_codes, first_day, limit_count=2000):
        """Collect every series that has at least *limit_count* rows.

        Series are gathered in this order: cumulative credit balance and its
        row-to-row difference (``meta_3``), trading by investor type
        (``meta_2``), exchange rates (``meta_1``), raw materials (``meta_5``)
        and finally the closing prices of *stock_codes*.  Labels for the
        non-stock series are stored in ``stock_info``.

        Args:
            stock_codes: iterable of stock codes to read from ``stock``.
            first_day: only rows with ``ymd >= first_day`` are read.
            limit_count: minimum number of rows a series needs to be kept.

        Returns:
            dict mapping a series key to a list of ``(ymd, value)`` tuples.
        """
        data = {}
        self.open()
        try:
            # Cumulative credit balance, plus its change from the previous row.
            rows = self._rows_if_enough(
                "SELECT ymd, dep2_1 FROM meta_3 where ymd>=? order by ymd",
                (first_day,), limit_count)
            if rows is not None:
                data['dep2_1'] = [(row[0], row[1]) for row in rows]
                data['dep2_1_diff'] = [
                    (cur[0], cur[1] - prev[1])
                    for prev, cur in zip(rows, rows[1:])
                ]
                self.stock_info['dep2_1'] = '신용잔고 누적'
                self.stock_info['dep2_1_diff'] = '신용잔고 누적 차이'

            # Trading trend by investor type: one series per value column.
            rows = self._rows_if_enough(
                "SELECT ymd, pri, fori, ins, ins0, ins1, ins2, ins3, ins4, ins5, cor FROM meta_2 where ymd>=? order by ymd",
                (first_day,), limit_count)
            if rows is not None:
                # Column 0 is ymd, so the value columns start at index 1.
                for column, (key, label) in enumerate(self._INVESTOR_LABELS, 1):
                    data[key] = [(row[0], row[column]) for row in rows]
                    self.stock_info[key] = label

            # Exchange rates.
            for exchange_code in self._distinct_codes(
                    "SELECT distinct code FROM meta_1"):
                rows = self._rows_if_enough(
                    "SELECT ymd, price FROM meta_1 where code=? and ymd>=? order by ymd",
                    (exchange_code, first_day), limit_count)
                if rows is not None:
                    data[exchange_code] = rows
                    self.stock_info[exchange_code] = exchange_code

            # Raw materials.
            for meterial_code in self._distinct_codes(
                    "SELECT distinct code FROM meta_5"):
                rows = self._rows_if_enough(
                    "SELECT ymd, close FROM meta_5 where code=? and ymd>=? order by ymd",
                    (meterial_code, first_day), limit_count)
                if rows is not None:
                    data[meterial_code] = rows
                    self.stock_info[meterial_code] = meterial_code

            # Closing prices of the requested stocks.
            for stock_code in stock_codes:
                rows = self._rows_if_enough(
                    "SELECT ymd, close FROM stock where CODE=? and ymd>=? order by ymd",
                    (stock_code, first_day), limit_count)
                if rows is not None:
                    data[stock_code] = rows
        finally:
            # Release the connection even when one of the queries fails.
            self.close()
        return data

    def debug(self, master_data, data):
        """Keep only series that line up with the master after trimming.

        For each series, the days present on only one side are computed (the
        longer side's extra days) and removed from the series.  The series is
        kept only when what remains is exactly as long as *master_data*.

        Args:
            master_data: list of ``(ymd, value)`` tuples of the master stock.
            data: dict mapping a series key to ``(ymd, value)`` tuples.

        Returns:
            dict with the surviving series, trimmed.
        """
        master_days = {item[0] for item in master_data}
        trimedData = {}
        for stock_code, stock_data in data.items():
            stock_days = {item[0] for item in stock_data}
            if len(master_data) < len(stock_data):
                diff = stock_days - master_days
            else:
                diff = master_days - stock_days
            tmp = [(item[0], item[1]) for item in stock_data
                   if item[0] not in diff]
            if len(tmp) == len(master_data):
                trimedData[stock_code] = tmp
        return trimedData

    def trim(self, master_data, data):
        """Drop from every series the days the master stock does not have.

        Args:
            master_data: list of ``(ymd, value)`` tuples of the master stock.
            data: dict mapping a series key to ``(ymd, value)`` tuples.

        Returns:
            dict with the same keys as *data*; each value keeps only the
            ``(ymd, value)`` pairs whose day also occurs in *master_data*.
        """
        # Built once: membership in this set is the whole filter.
        master_days = {item[0] for item in master_data}
        return {
            stock_code: [(item[0], item[1]) for item in stock_data
                         if item[0] in master_days]
            for stock_code, stock_data in data.items()
        }

    def analyzeCorRelation(self, master_data, trimedData, lag=3):
        """Correlate each series with the master's value *lag* rows later.

        Both sides are first restricted to their common days, sorted
        ascending.  The series' values are then paired with the master's
        values *lag* positions further on, and the Pearson correlation of the
        pairs is stored.

        Args:
            master_data: list of ``(ymd, value)`` tuples of the master stock.
            trimedData: dict mapping a series key to ``(ymd, value)`` tuples.
            lag: number of common observations by which the master is shifted
                forward relative to each series.  Must not be negative.

        Returns:
            dict mapping ``"<code>_<label>"`` to the correlation.  The label
            comes from ``stock_info`` and falls back to the code itself when
            no label is known.  The score is NaN when too few pairs remain.

        Raises:
            ValueError: if *lag* is negative.
        """
        if lag < 0:
            raise ValueError("lag must not be negative")
        master_set = {item[0]: item[1] for item in master_data}
        corr_scores = {}
        for stock_code, stock_data in trimedData.items():
            stock_set = {item[0]: item[1] for item in stock_data}
            intersection = sorted(master_set.keys() & stock_set.keys())
            master_list = [master_set[day] for day in intersection]
            stock_list = [stock_set[day] for day in intersection]
            # An explicit end index is used because ``[:-0]`` would be empty.
            pair_count = max(len(intersection) - lag, 0)
            lst = [master_list[lag:], stock_list[:pair_count]]
            corr = pd.DataFrame(lst).T.corr(method='pearson')
            label = self.stock_info.get(stock_code, stock_code)
            corr_scores["%s_%s" % (stock_code, label)] = corr.at[0, 1]
        return corr_scores

    def analyze(self, master_code, limit_count=1501, lag=3):
        """Run the whole pipeline for one master stock.

        Args:
            master_code: code of the stock every series is compared against.
            limit_count: number of most recent master rows to use, and the
                minimum number of rows another series needs to be included.
            lag: forwarded to :meth:`analyzeCorRelation`.

        Returns:
            dict mapping ``"<code>_<label>"`` to a correlation score; empty
            when the master stock has no rows.
        """
        stock_codes = self.get_all_stock_code()
        master_data = self.load(master_code, limit_count)
        if not master_data:
            # Without master rows there is no first day to start from.
            return {}
        first_day = master_data[0][0]
        data = self.getClosePrice(stock_codes, first_day, limit_count)
        trimedData = self.trim(master_data, data)
        return self.analyzeCorRelation(master_data, trimedData, lag)
if __name__ == "__main__":
    # Script entry point: rank every series in resources/stock.db against one
    # master stock and print the scores from highest to lowest.
    PROJECT_HOME = "."
    RESOURCE_PATH = os.path.join(PROJECT_HOME, 'resources')
    stockFileName = os.path.join(RESOURCE_PATH, 'stock.db')
    stockPriceDirectionAnalyzer = StockPriceDirectionAnalyzer(stockFileName)

    # A disabled variant of this run used master_code='122630' and printed
    # only the scores above 0.8.
    corr_scores = stockPriceDirectionAnalyzer.analyze(master_code='252670')

    # NaN compares false against every number, so sorting a list that contains
    # undefined correlations gives an unreliable order.  Rank the defined
    # scores and list the undefined ones afterwards.
    defined = [item for item in corr_scores.items() if not pd.isna(item[1])]
    undefined = [item for item in corr_scores.items() if pd.isna(item[1])]
    corr_scores_list = sorted(defined, key=lambda item: item[1], reverse=True)
    corr_scores_list.extend(undefined)
    for item in corr_scores_list:
        print("%s %4.3f" % (item[0], item[1]))
    print('done...')