PyAlgoTrade之Feed

422 阅读3分钟

Feeds 在PyAlogTrade中是作为数据供给者存在的。比如如果想要用csv中的数据,就需要使用CsvFeed。 在PyAlogTrade中默认做了如下几种实现:

image.png 那么Feed倒是是用来干什么的? 我们可以具体可以看下BarFeed的实现:

class BarFeed(barfeed.BaseBarFeed):  
    def __init__(self, frequency, maxLen=None):  
        super(BarFeed, self).__init__(frequency, maxLen)  
  
        self.__bars = {}  
        self.__nextPos = {}  
        self.__started = False  
		self.__currDateTime = None

BarFeed的定义非常简单, 其最核心的两个元素是 bars 和 nextPos。其中bars中保存当前的支持,而nextpos 是用来迭代时候寻找下一个值所在。 所以 BarFeed的实现是一个典型的以时间作为循环的单项链表。

以官方的文档为例子,我们看看其内部发生了什么?

from pyalgotrade import strategy
from pyalgotrade.barfeed import quandlfeed

class MyStrategy(strategy.BacktestingStrategy):
    def __init__(self, feed, instrument):
        super(MyStrategy, self).__init__(feed)
        self.__instrument = instrument

    def onBars(self, bars):
        bar = bars[self.__instrument]
        self.info(bar.getClose())

# Load the bar feed from the CSV file
feed = quandlfeed.Feed()
feed.addBarsFromCSV("orcl", "WIKI-ORCL-2000-quandl.csv")

# Evaluate the strategy with the feed's bars.
myStrategy = MyStrategy(feed, "orcl")
myStrategy.run()

这里我们最关心的是如何把csv的数据,转为为FeedBar的。 我们去最核心的关键代码:

class BarFeed(membf.BarFeed):
	def addBarsFromCSV(self, instrument, path, rowParser, skipMalformedBars=False):  
		....
    # Load the csv file  
	loadedBars = []  
	reader = csvutils.FastDictReader(open(path, "r"), fieldnames=rowParser.getFieldNames(), delimiter=rowParser.getDelimiter())  
	for row in reader:  
	    bar_ = parse_bar(row)  
	    if bar_ is not None and (self.__barFilter is None or self.__barFilter.includeBar(bar_)):  
	    loadedBars.append(bar_)  
	  
	self.addBarsFromSequence(instrument, loadedBars)

没什么特殊的,从CSV取数据,然后addBarsFromSequence, 然后我们再往下跟:

class BarFeed(barfeed.BaseBarFeed):
	def addBarsFromSequence(self, instrument, bars):  
	    if self.__started:  
	        raise Exception("Can't add more bars once you started consuming bars")  
	  
	    self.__bars.setdefault(instrument, [])  
	    self.__nextPos.setdefault(instrument, 0)  
	  
	    # Add and sort the bars  
	 self.__bars[instrument].extend(bars)  
	    self.__bars[instrument].sort(key=lambda b: b.getDateTime())  
	  
	    self.registerInstrument(instrument)

这里就是最核心的关键代码所在了, 把读出来的数据放入 self.__bars 中,然后以时间升序排序。

自定义panda的 feed

panda是最常用的库, 很多时候需要用它做数据的操作, 这时候就需要写一个PanadaFeed , 自定义一个Feed最核心的是什么? 关键是几个点:

  1. class BarFeed(membf.BarFeed)
  2. BarFeed:addBarsFromSequence

那么又面临一个问题, addBarsFromSequence(self, instrument, bars): 参数中的bars是有格式要求的,格式是什么样子的, 这里我们就可以借助于 csvfeed.GenericRowParser 做格式解析和校验:

class GenericRowParser(RowParser):  
    def __init__(self, columnNames, dateTimeFormat, dailyBarTime, frequency, timezone, barClass=bar.BasicBar):  
        self.__dateTimeFormat = dateTimeFormat  
        self.__dailyBarTime = dailyBarTime  
        self.__frequency = frequency  
        self.__timezone = timezone  
        self.__haveAdjClose = False  
		 self.__barClass = barClass  
        # Column names.  
		 self.__dateTimeColName = columnNames["datetime"]  
        self.__openColName = columnNames["open"]  
        self.__highColName = columnNames["high"]  
        self.__lowColName = columnNames["low"]  
        self.__closeColName = columnNames["close"]  
        self.__volumeColName = columnNames["volume"]  
        self.__adjCloseColName = columnNames["adj_close"]  
        self.__columnNames = columnNames

	def parseBar(self, csvRowDict):  
	    dateTime = self._parseDate(csvRowDict[self.__dateTimeColName])  
	    open_ = float(csvRowDict[self.__openColName])  
	    high = float(csvRowDict[self.__highColName])  
	    low = float(csvRowDict[self.__lowColName])  
	    close = float(csvRowDict[self.__closeColName])  
	    volume = float(csvRowDict[self.__volumeColName])  
	    adjClose = None  
	 if self.__adjCloseColName is not None:  
	        adjCloseValue = csvRowDict.get(self.__adjCloseColName, "")  
	        if len(adjCloseValue) > 0:  
	            adjClose = float(adjCloseValue)  
	            self.__haveAdjClose = True  
	  
	 # Process extra columns.  
	 extra = {}  
	    for k, v in six.iteritems(csvRowDict):  
	        if k not in self.__columnNames.values():  
	            extra[k] = csvutils.float_or_string(v)  
	  
	    return self.__barClass(  
	        dateTime, open_, high, low, close, volume, adjClose, self.__frequency, extra=extra  
    )
		

其返回的就是这个BasicBar

class BasicBar(Bar):  
    # Optimization to reduce memory footprint.  
 __slots__ = (  
        '__dateTime',  
 '__open',  
 '__close',  
 '__high',  
 '__low',  
 '__volume',  
 '__adjClose',  
 '__frequency',  
 '__useAdjustedValue',  
 '__extra',  
 )

这样我们就可以看下我们的核心伪代码是什么样子的:

#使用 GenericeRowParser作为解析器
	rowParser = csvfeed.GenericRowParser(  
    self.__columnNames,  
 self.__dateTimeFormat,  
 self.getDailyBarTime(),  
 self.getFrequency(),  
 timezone,  
 self.__barClass  
)

# 把dataframe转为为dict
list_of_dicts = df.fillna('').astype(str).to_dict('records')


# 伪代码
bars =  for index in list_of_dicts.keys():rowParser
# 调用 
self.addBarsFromSequence(instrument, loadedBars)

总结

PyAlgoTrade是以Bar作为最基本的存续结构,其是一个以dateTime作为索引,有open, close, low, volume等字段的结构。Feed可以认为是一个解析器,从其他的数据存储格式中转换到 BarFeed存储,其存储结构其实是一个单向链表,以__nextPos作为索引字段。