Feeds 在PyAlogTrade中是作为数据供给者存在的。比如如果想要用csv中的数据,就需要使用CsvFeed。 在PyAlogTrade中默认做了如下几种实现:
那么Feed倒是是用来干什么的? 我们可以具体可以看下BarFeed的实现:
class BarFeed(barfeed.BaseBarFeed):
def __init__(self, frequency, maxLen=None):
super(BarFeed, self).__init__(frequency, maxLen)
self.__bars = {}
self.__nextPos = {}
self.__started = False
self.__currDateTime = None
BarFeed的定义非常简单, 其最核心的两个元素是 bars 和 nextPos。其中bars中保存当前的支持,而nextpos 是用来迭代时候寻找下一个值所在。 所以 BarFeed的实现是一个典型的以时间作为循环的单项链表。
以官方的文档为例子,我们看看其内部发生了什么?
from pyalgotrade import strategy
from pyalgotrade.barfeed import quandlfeed
class MyStrategy(strategy.BacktestingStrategy):
def __init__(self, feed, instrument):
super(MyStrategy, self).__init__(feed)
self.__instrument = instrument
def onBars(self, bars):
bar = bars[self.__instrument]
self.info(bar.getClose())
# Load the bar feed from the CSV file
feed = quandlfeed.Feed()
feed.addBarsFromCSV("orcl", "WIKI-ORCL-2000-quandl.csv")
# Evaluate the strategy with the feed's bars.
myStrategy = MyStrategy(feed, "orcl")
myStrategy.run()
这里我们最关心的是如何把csv的数据,转为为FeedBar的。 我们去最核心的关键代码:
class BarFeed(membf.BarFeed):
def addBarsFromCSV(self, instrument, path, rowParser, skipMalformedBars=False):
....
# Load the csv file
loadedBars = []
reader = csvutils.FastDictReader(open(path, "r"), fieldnames=rowParser.getFieldNames(), delimiter=rowParser.getDelimiter())
for row in reader:
bar_ = parse_bar(row)
if bar_ is not None and (self.__barFilter is None or self.__barFilter.includeBar(bar_)):
loadedBars.append(bar_)
self.addBarsFromSequence(instrument, loadedBars)
没什么特殊的,从CSV取数据,然后addBarsFromSequence, 然后我们再往下跟:
class BarFeed(barfeed.BaseBarFeed):
def addBarsFromSequence(self, instrument, bars):
if self.__started:
raise Exception("Can't add more bars once you started consuming bars")
self.__bars.setdefault(instrument, [])
self.__nextPos.setdefault(instrument, 0)
# Add and sort the bars
self.__bars[instrument].extend(bars)
self.__bars[instrument].sort(key=lambda b: b.getDateTime())
self.registerInstrument(instrument)
这里就是最核心的关键代码所在了, 把读出来的数据放入 self.__bars 中,然后以时间升序排序。
自定义panda的 feed
panda是最常用的库, 很多时候需要用它做数据的操作, 这时候就需要写一个PanadaFeed , 自定义一个Feed最核心的是什么? 关键是几个点:
- class BarFeed(membf.BarFeed)
- BarFeed:addBarsFromSequence
那么又面临一个问题, addBarsFromSequence(self, instrument, bars): 参数中的bars是有格式要求的,格式是什么样子的, 这里我们就可以借助于 csvfeed.GenericRowParser 做格式解析和校验:
class GenericRowParser(RowParser):
def __init__(self, columnNames, dateTimeFormat, dailyBarTime, frequency, timezone, barClass=bar.BasicBar):
self.__dateTimeFormat = dateTimeFormat
self.__dailyBarTime = dailyBarTime
self.__frequency = frequency
self.__timezone = timezone
self.__haveAdjClose = False
self.__barClass = barClass
# Column names.
self.__dateTimeColName = columnNames["datetime"]
self.__openColName = columnNames["open"]
self.__highColName = columnNames["high"]
self.__lowColName = columnNames["low"]
self.__closeColName = columnNames["close"]
self.__volumeColName = columnNames["volume"]
self.__adjCloseColName = columnNames["adj_close"]
self.__columnNames = columnNames
def parseBar(self, csvRowDict):
dateTime = self._parseDate(csvRowDict[self.__dateTimeColName])
open_ = float(csvRowDict[self.__openColName])
high = float(csvRowDict[self.__highColName])
low = float(csvRowDict[self.__lowColName])
close = float(csvRowDict[self.__closeColName])
volume = float(csvRowDict[self.__volumeColName])
adjClose = None
if self.__adjCloseColName is not None:
adjCloseValue = csvRowDict.get(self.__adjCloseColName, "")
if len(adjCloseValue) > 0:
adjClose = float(adjCloseValue)
self.__haveAdjClose = True
# Process extra columns.
extra = {}
for k, v in six.iteritems(csvRowDict):
if k not in self.__columnNames.values():
extra[k] = csvutils.float_or_string(v)
return self.__barClass(
dateTime, open_, high, low, close, volume, adjClose, self.__frequency, extra=extra
)
其返回的就是这个BasicBar
class BasicBar(Bar):
# Optimization to reduce memory footprint.
__slots__ = (
'__dateTime',
'__open',
'__close',
'__high',
'__low',
'__volume',
'__adjClose',
'__frequency',
'__useAdjustedValue',
'__extra',
)
这样我们就可以看下我们的核心伪代码是什么样子的:
#使用 GenericeRowParser作为解析器
rowParser = csvfeed.GenericRowParser(
self.__columnNames,
self.__dateTimeFormat,
self.getDailyBarTime(),
self.getFrequency(),
timezone,
self.__barClass
)
# 把dataframe转为为dict
list_of_dicts = df.fillna('').astype(str).to_dict('records')
# 伪代码
bars = for index in list_of_dicts.keys():rowParser
# 调用
self.addBarsFromSequence(instrument, loadedBars)
总结
PyAlgoTrade是以Bar作为最基本的存续结构,其是一个以dateTime作为索引,有open, close, low, volume等字段的结构。Feed可以认为是一个解析器,从其他的数据存储格式中转换到 BarFeed存储,其存储结构其实是一个单向链表,以__nextPos作为索引字段。